Browse code

cmd/docker-proxy: UDP: thread-safe Write and Close

The UDP proxy used by cmd/docker-proxy is executing Write and Close in
two separate goroutines, such that a Close could interrupt an in-flight
Write.

Introduce a `connTrackEntry` that wraps a `net.Conn` and a `sync.Mutex`
to ensure that Write and Close are serialized.

Signed-off-by: Albin Kerouanton <albinker@gmail.com>

Albin Kerouanton authored on 2025/03/17 17:08:53
Showing 1 changed files
... ...
@@ -43,7 +43,25 @@ func newConnTrackKey(addr *net.UDPAddr) *connTrackKey {
43 43
 	}
44 44
 }
45 45
 
46
-type connTrackMap map[connTrackKey]*net.UDPConn
46
+type connTrackMap map[connTrackKey]*connTrackEntry
47
+
48
+// connTrackEntry wraps a UDP connection to provide thread-safe [net.Conn.Write]
49
+// and [net.Conn.Close] operations.
50
+type connTrackEntry struct {
51
+	conn *net.UDPConn
52
+	// This lock should be held before calling Write or Close on the wrapped
53
+	// net.UDPConn. Read can be called concurrently to these operations.
54
+	//
55
+	// Never lock mu without locking UDPProxy.connTrackLock first.
56
+	mu sync.Mutex
57
+}
58
+
59
+func newConnTrackEntry(conn *net.UDPConn) *connTrackEntry {
60
+	return &connTrackEntry{
61
+		conn: conn,
62
+		mu:   sync.Mutex{},
63
+	}
64
+}
47 65
 
48 66
 // UDPProxy is proxy for which handles UDP datagrams. It implements the Proxy
49 67
 // interface to handle UDP traffic forwarding between the frontend and backend
... ...
@@ -68,12 +86,13 @@ func NewUDPProxy(listener *net.UDPConn, backendAddr *net.UDPAddr, ipVer ipVersio
68 68
 	}, nil
69 69
 }
70 70
 
71
-func (proxy *UDPProxy) replyLoop(proxyConn *net.UDPConn, serverAddr net.IP, clientAddr *net.UDPAddr, clientKey *connTrackKey) {
71
+func (proxy *UDPProxy) replyLoop(cte *connTrackEntry, serverAddr net.IP, clientAddr *net.UDPAddr, clientKey *connTrackKey) {
72 72
 	defer func() {
73 73
 		proxy.connTrackLock.Lock()
74 74
 		delete(proxy.connTrackTable, *clientKey)
75
+		cte.mu.Lock()
75 76
 		proxy.connTrackLock.Unlock()
76
-		proxyConn.Close()
77
+		cte.conn.Close()
77 78
 	}()
78 79
 
79 80
 	var oob []byte
... ...
@@ -87,9 +106,9 @@ func (proxy *UDPProxy) replyLoop(proxyConn *net.UDPConn, serverAddr net.IP, clie
87 87
 
88 88
 	readBuf := make([]byte, UDPBufSize)
89 89
 	for {
90
-		proxyConn.SetReadDeadline(time.Now().Add(UDPConnTrackTimeout))
90
+		cte.conn.SetReadDeadline(time.Now().Add(UDPConnTrackTimeout))
91 91
 	again:
92
-		read, err := proxyConn.Read(readBuf)
92
+		read, err := cte.conn.Read(readBuf)
93 93
 		if err != nil {
94 94
 			if err, ok := err.(*net.OpError); ok && err.Err == syscall.ECONNREFUSED {
95 95
 				// This will happen if the last write failed
... ...
@@ -134,15 +153,16 @@ func (proxy *UDPProxy) Run() {
134 134
 
135 135
 		fromKey := newConnTrackKey(from)
136 136
 		proxy.connTrackLock.Lock()
137
-		proxyConn, hit := proxy.connTrackTable[*fromKey]
137
+		cte, hit := proxy.connTrackTable[*fromKey]
138 138
 		if !hit {
139
-			proxyConn, err = net.DialUDP("udp", nil, proxy.backendAddr)
139
+			proxyConn, err := net.DialUDP("udp", nil, proxy.backendAddr)
140 140
 			if err != nil {
141 141
 				log.Printf("Can't proxy a datagram to udp/%s: %s\n", proxy.backendAddr, err)
142 142
 				proxy.connTrackLock.Unlock()
143 143
 				continue
144 144
 			}
145
-			proxy.connTrackTable[*fromKey] = proxyConn
145
+			cte = newConnTrackEntry(proxyConn)
146
+			proxy.connTrackTable[*fromKey] = cte
146 147
 
147 148
 			daddr, err := readDestFromCmsg(oob, proxy.ipVer)
148 149
 			if err != nil {
... ...
@@ -151,17 +171,20 @@ func (proxy *UDPProxy) Run() {
151 151
 				continue
152 152
 			}
153 153
 
154
-			go proxy.replyLoop(proxyConn, daddr, from, fromKey)
154
+			go proxy.replyLoop(cte, daddr, from, fromKey)
155 155
 		}
156
+		cte.mu.Lock()
156 157
 		proxy.connTrackLock.Unlock()
158
+		cte.conn.SetWriteDeadline(time.Now().Add(UDPConnTrackTimeout))
157 159
 		for i := 0; i != read; {
158
-			written, err := proxyConn.Write(readBuf[i:read])
160
+			written, err := cte.conn.Write(readBuf[i:read])
159 161
 			if err != nil {
160 162
 				log.Printf("Can't proxy a datagram to udp/%s: %s\n", proxy.backendAddr, err)
161 163
 				break
162 164
 			}
163 165
 			i += written
164 166
 		}
167
+		cte.mu.Unlock()
165 168
 	}
166 169
 }
167 170
 
... ...
@@ -194,12 +217,15 @@ func readDestFromCmsg(oob []byte, ipVer ipVersion) (_ net.IP, err error) {
194 194
 	return cm.Dst, nil
195 195
 }
196 196
 
197
-// Close stops forwarding the traffic.
197
+// Close ungracefully stops forwarding the traffic.
198 198
 func (proxy *UDPProxy) Close() {
199 199
 	proxy.listener.Close()
200 200
 	proxy.connTrackLock.Lock()
201 201
 	defer proxy.connTrackLock.Unlock()
202
-	for _, conn := range proxy.connTrackTable {
203
-		conn.Close()
202
+	for _, cte := range proxy.connTrackTable {
203
+		// Unlike the GC logic in replyLoop, we want to close the connections
204
+		// immediately, even if there are pending and in-progress writes. So no
205
+		// need to lock cte.mu here.
206
+		cte.conn.Close()
204 207
 	}
205 208
 }