Browse code

cmd/docker-proxy: do not eagerly GC one-sided UDP conns

The UDP proxy is setting a deadline of 90 seconds when reading from the
backend. If no data is received within this interval, it reclaims the
connection.

This means the backend sees a different connection every 90 seconds
if it never sends any reply back to the client.

This change prevents the proxy from eagerly GC'ing such connections by
taking into account the last time a datagram was proxied to the backend.

Signed-off-by: Albin Kerouanton <albinker@gmail.com>

Albin Kerouanton authored on 2025/03/17 22:48:48
Showing 2 changed files
... ...
@@ -5,6 +5,7 @@ import (
5 5
 	"errors"
6 6
 	"log"
7 7
 	"net"
8
+	"os"
8 9
 	"sync"
9 10
 	"syscall"
10 11
 	"time"
... ...
@@ -49,7 +50,8 @@ type connTrackMap map[connTrackKey]*connTrackEntry
49 49
 // connTrackEntry wraps a UDP connection to provide thread-safe [net.Conn.Write]
50 50
 // and [net.Conn.Close] operations.
51 51
 type connTrackEntry struct {
52
-	conn *net.UDPConn
52
+	conn  *net.UDPConn
53
+	lastW time.Time
53 54
 	// This lock should be held before calling Write or Close on the wrapped
54 55
 	// net.UDPConn. Read can be called concurrently to these operations.
55 56
 	//
... ...
@@ -64,6 +66,12 @@ func newConnTrackEntry(conn *net.UDPConn) *connTrackEntry {
64 64
 	}
65 65
 }
66 66
 
67
+func (cte *connTrackEntry) lastWrite() time.Time {
68
+	cte.mu.Lock()
69
+	defer cte.mu.Unlock()
70
+	return cte.lastW
71
+}
72
+
67 73
 // UDPProxy is a proxy which handles UDP datagrams. It implements the Proxy
68 74
 // interface to handle UDP traffic forwarding between the frontend and backend
69 75
 // addresses.
... ...
@@ -121,6 +129,15 @@ func (proxy *UDPProxy) replyLoop(cte *connTrackEntry, serverAddr net.IP, clientA
121 121
 				// expires:
122 122
 				goto again
123 123
 			}
124
+			// If the UDP connection is one-sided (i.e. the backend never sends
125
+			// replies), the connTrackEntry should not be GC'd until no writes
126
+			// happen for proxy.connTrackTimeout.
127
+			//
128
+			// Since the ReadDeadline is set to proxy.connTrackTimeout, in such a
129
+			// case, the connTrackEntry will be GC'd after at most 2 * proxy.connTrackTimeout.
130
+			if errors.Is(err, os.ErrDeadlineExceeded) && time.Since(cte.lastWrite()) < proxy.connTrackTimeout {
131
+				continue
132
+			}
124 133
 			return
125 134
 		}
126 135
 		for i := 0; i != read; {
... ...
@@ -186,6 +203,7 @@ func (proxy *UDPProxy) Run() {
186 186
 				break
187 187
 			}
188 188
 			i += written
189
+			cte.lastW = time.Now()
189 190
 		}
190 191
 		cte.mu.Unlock()
191 192
 	}
192 193
new file mode 100644
... ...
@@ -0,0 +1,78 @@
0
+package main
1
+
2
+import (
3
+	"net"
4
+	"testing"
5
+	"time"
6
+
7
+	"gotest.tools/v3/assert"
8
+)
9
+
10
+// TestUDPOneSided makes sure that the conntrack entry isn't GC'd if the
11
+// backend never writes to the UDP client.
12
+func TestUDPOneSided(t *testing.T) {
13
+	frontend, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
14
+	assert.NilError(t, err)
15
+	defer frontend.Close()
16
+
17
+	backend, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
18
+	assert.NilError(t, err)
19
+	defer backend.Close()
20
+
21
+	type udpMsg struct {
22
+		data  []byte
23
+		saddr *net.UDPAddr
24
+	}
25
+	msgs := make(chan udpMsg)
26
+	go func() {
27
+		for {
28
+			buf := make([]byte, 1024)
29
+			n, saddr, err := backend.ReadFromUDP(buf)
30
+			if err != nil {
31
+				return
32
+			}
33
+			msgs <- udpMsg{data: buf[:n], saddr: saddr}
34
+		}
35
+	}()
36
+
37
+	proxy, err := NewUDPProxy(frontend, backend.LocalAddr().(*net.UDPAddr), ip4)
38
+	assert.NilError(t, err)
39
+	defer proxy.Close()
40
+
41
+	const connTrackTimeout = 1 * time.Second
42
+	proxy.connTrackTimeout = connTrackTimeout
43
+
44
+	go func() {
45
+		proxy.Run()
46
+	}()
47
+
48
+	client, err := net.DialUDP("udp", nil, frontend.LocalAddr().(*net.UDPAddr))
49
+	assert.NilError(t, err)
50
+	defer client.Close()
51
+
52
+	var expSaddr *net.UDPAddr
53
+	for i := range 15 {
54
+		_, err = client.Write([]byte("hello"))
55
+		assert.NilError(t, err)
56
+		time.Sleep(100 * time.Millisecond)
57
+
58
+		msg := <-msgs
59
+		assert.Equal(t, string(msg.data), "hello")
60
+		if i == 0 {
61
+			expSaddr = msg.saddr
62
+		} else {
63
+			assert.Equal(t, msg.saddr.Port, expSaddr.Port)
64
+		}
65
+	}
66
+
67
+	// The conntrack entry is checked every connTrackTimeout, but the latest
68
+	// write might be less than connTrackTimeout ago. So we need to wait for
69
+	// at least twice the conntrack timeout to make sure the entry is GC'd.
70
+	time.Sleep(2 * connTrackTimeout)
71
+	_, err = client.Write([]byte("hello"))
72
+	assert.NilError(t, err)
73
+
74
+	msg := <-msgs
75
+	assert.Equal(t, string(msg.data), "hello")
76
+	assert.Check(t, msg.saddr.Port != expSaddr.Port)
77
+}