Browse code

libnet/pa: OSAllocator: listen after bind

Move the listen syscall into the `OSAllocator` so that, when
`RequestPortsInRange` returns, callers are guaranteed that the allocated
port isn't in use by another process.

Bind and listen syscalls were previously split because listening before
inserting DNAT rules could cause connections to be accepted by the
kernel, so packets would never be forwarded to the container.

But, pulling them apart has an undesirable drawback: if another process
is racing against the Engine, and starts listening on the same port,
the conflict wouldn't be detected until OSAllocator's callers issue a
'listen' syscall. This means that callers need to implement their own
retry logic.

To overcome both drawbacks, set a cBPF socket filter on the socket
before it's bound, and let callers call `DetachSocketFilter` to remove
it. Now, callers are guaranteed that the port is free to use, and no
connections will be accepted prematurely.

For TCP / SCTP clients, this means that they'll send the first handshake
packet (e.g. SYN), but the kernel won't reply (e.g. SYN-ACK), and they
will retry until DNAT rules are configured or the socket filter is
removed.

Signed-off-by: Albin Kerouanton <albinker@gmail.com>

Albin Kerouanton authored on 2025/08/08 17:31:31
Showing 3 changed files
... ...
@@ -2,15 +2,19 @@ package portallocator
2 2
 
3 3
 import (
4 4
 	"context"
5
+	"errors"
5 6
 	"fmt"
6 7
 	"net"
7 8
 	"net/netip"
8 9
 	"os"
10
+	"runtime"
9 11
 	"syscall"
10 12
 
11 13
 	"github.com/containerd/log"
12 14
 	"github.com/ishidawataru/sctp"
13 15
 	"github.com/moby/moby/v2/daemon/libnetwork/types"
16
+	"golang.org/x/net/bpf"
17
+	"golang.org/x/sys/unix"
14 18
 )
15 19
 
16 20
 type OSAllocator struct {
... ...
@@ -27,16 +31,15 @@ func NewOSAllocator() OSAllocator {
27 27
 }
28 28
 
29 29
 // RequestPortsInRange reserves a port available in the range [portStart, portEnd]
30
-// for all the specified addrs, and then try to bind those addresses to allocate
31
-// the port from the OS. It returns the allocated port, and all the sockets
32
-// bound, or an error if the reserved port isn't available. Callers must take
33
-// care of closing the returned sockets.
30
+// for all the specified addrs, and then try to bind/listen those addresses to
31
+// allocate the port from the OS.
34 32
 //
35
-// Due to the semantic of SO_REUSEADDR, the OSAllocator can't fully determine
36
-// if a port is free when binding 0.0.0.0 or ::. If another socket is binding
37
-// the same port, but it's not listening to it yet, the bind will succeed but a
38
-// subsequent listen might fail. For this reason, RequestPortsInRange doesn't
39
-// retry on failure — it's caller's responsibility.
33
+// It returns the allocated port, and all the sockets bound, or an error if the
34
+// reserved port isn't available. These sockets have a filter set to ensure that
35
+// the kernel doesn't accept connections on these. Callers must take care of
36
+// calling DetachSocketFilter once they're ready to accept connections (e.g. after
37
+// setting up DNAT rules, and before starting the userland proxy), and they must
38
+// take care of closing the returned sockets.
40 39
 //
41 40
 // It's safe for concurrent use.
42 41
 func (pa OSAllocator) RequestPortsInRange(addrs []net.IP, proto types.Protocol, portStart, portEnd int) (_ int, _ []*os.File, retErr error) {
... ...
@@ -73,11 +76,11 @@ func (pa OSAllocator) RequestPortsInRange(addrs []net.IP, proto types.Protocol,
73 73
 		var sock *os.File
74 74
 		switch proto {
75 75
 		case types.TCP:
76
-			sock, err = bindTCPOrUDP(addrPort, syscall.SOCK_STREAM, syscall.IPPROTO_TCP)
76
+			sock, err = listenTCP(addrPort)
77 77
 		case types.UDP:
78 78
 			sock, err = bindTCPOrUDP(addrPort, syscall.SOCK_DGRAM, syscall.IPPROTO_UDP)
79 79
 		case types.SCTP:
80
-			sock, err = bindSCTP(addrPort)
80
+			sock, err = listenSCTP(addrPort)
81 81
 		default:
82 82
 			return 0, nil, fmt.Errorf("protocol %s not supported", proto)
83 83
 		}
... ...
@@ -101,6 +104,20 @@ func (pa OSAllocator) ReleasePorts(addrs []net.IP, proto types.Protocol, port in
101 101
 	}
102 102
 }
103 103
 
104
+func listenTCP(addr netip.AddrPort) (_ *os.File, retErr error) {
105
+	boundSocket, err := bindTCPOrUDP(addr, syscall.SOCK_STREAM, syscall.IPPROTO_TCP)
106
+	if err != nil {
107
+		return nil, err
108
+	}
109
+
110
+	somaxconn := -1 // silently capped to "/proc/sys/net/core/somaxconn"
111
+	if err := syscall.Listen(int(boundSocket.Fd()), somaxconn); err != nil {
112
+		return nil, fmt.Errorf("failed to listen on tcp socket: %w", err)
113
+	}
114
+
115
+	return boundSocket, nil
116
+}
117
+
104 118
 func bindTCPOrUDP(addr netip.AddrPort, typ int, proto types.Protocol) (_ *os.File, retErr error) {
105 119
 	var domain int
106 120
 	var sa syscall.Sockaddr
... ...
@@ -128,6 +145,16 @@ func bindTCPOrUDP(addr netip.AddrPort, typ int, proto types.Protocol) (_ *os.Fil
128 128
 		}
129 129
 	}
130 130
 
131
+	// We need to listen to make sure that the port is free, and no other process is racing against us to acquire this
132
+	// port. But listening means that connections could be accepted before DNAT rules are inserted, and they'd never
133
+	// reach the container. To avoid this, set a socket filter to drop all connections — TCP SYNs will be
134
+	// re-transmitted anyway. Callers must call DetachSocketFilter.
135
+	//
136
+	// Set the socket filter _before_ binding the socket to make sure that no UDP datagrams will fill the queue.
137
+	if err := setSocketFilter(sd); err != nil {
138
+		return nil, fmt.Errorf("failed to set drop packets filter for %s/%s: %w", addr, proto, err)
139
+	}
140
+
131 141
 	if domain == syscall.AF_INET6 {
132 142
 		syscall.SetsockoptInt(sd, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 1)
133 143
 	}
... ...
@@ -158,8 +185,21 @@ func bindTCPOrUDP(addr netip.AddrPort, typ int, proto types.Protocol) (_ *os.Fil
158 158
 	return boundSocket, nil
159 159
 }
160 160
 
161
-// bindSCTP is based on sctp.ListenSCTP. The socket is created and bound, but
162
-// does not start listening.
161
+// listenSCTP is based on sctp.ListenSCTP.
162
+func listenSCTP(addr netip.AddrPort) (_ *os.File, retErr error) {
163
+	boundSocket, err := bindSCTP(addr)
164
+	if err != nil {
165
+		return nil, err
166
+	}
167
+
168
+	somaxconn := -1 // silently capped to "/proc/sys/net/core/somaxconn"
169
+	if err := syscall.Listen(int(boundSocket.Fd()), somaxconn); err != nil {
170
+		return nil, fmt.Errorf("failed to listen on sctp socket: %w", err)
171
+	}
172
+
173
+	return boundSocket, nil
174
+}
175
+
163 176
 func bindSCTP(addr netip.AddrPort) (_ *os.File, retErr error) {
164 177
 	domain := syscall.AF_INET
165 178
 	if addr.Addr().Unmap().Is6() {
... ...
@@ -190,9 +230,58 @@ func bindSCTP(addr netip.AddrPort) (_ *os.File, retErr error) {
190 190
 		return nil, fmt.Errorf("failed to bind host port %s/sctp: %w", addr, err)
191 191
 	}
192 192
 
193
+	// We need to listen to make sure that the port is free, and no other process is racing against us to acquire this
194
+	// port. But listening means that connections could be accepted before DNAT rules are inserted, and they'd never
195
+	// reach the container. To avoid this, set a socket filter to drop all connections — SCTP handshake will be
196
+	// re-transmitted anyway. Callers must call DetachSocketFilter.
197
+	if err := setSocketFilter(sd); err != nil {
198
+		return nil, fmt.Errorf("failed to set drop packets filter for %s/sctp: %w", addr, err)
199
+	}
200
+
193 201
 	boundSocket := os.NewFile(uintptr(sd), "listener")
194 202
 	if boundSocket == nil {
195 203
 		return nil, fmt.Errorf("failed to convert socket %s/sctp", addr)
196 204
 	}
197 205
 	return boundSocket, nil
198 206
 }
207
+
208
+// DetachSocketFilter removes the BPF filter set during port allocation to prevent the kernel from accepting connections
209
+// before DNAT rules are inserted.
210
+func DetachSocketFilter(f *os.File) error {
211
+	return unix.SetsockoptInt(int(f.Fd()), syscall.SOL_SOCKET, syscall.SO_DETACH_FILTER, 0 /* ignored */)
212
+}
213
+
214
+// setSocketFilter sets a cBPF program on socket sd to drop all packets. To start receiving packets on this socket,
215
+// callers must call DetachSocketFilter.
216
+func setSocketFilter(sd int) error {
217
+	asm, err := bpf.Assemble([]bpf.Instruction{
218
+		// A cBPF program attached to a socket with SO_ATTACH_FILTER and
219
+		// returning 0 tells the kernel to drop all packets.
220
+		bpf.RetConstant{Val: 0x0},
221
+	})
222
+	if err != nil {
223
+		// (bpf.RetConstant).Assemble() doesn't return an error, so this should
224
+		// be unreachable code.
225
+		return fmt.Errorf("attaching socket filter: %w", err)
226
+	}
227
+	// Make sure the asm slice is not GC'd before setsockopt is called
228
+	defer runtime.KeepAlive(asm)
229
+
230
+	if len(asm) == 0 {
231
+		return errors.New("attaching socket filter: empty BPF program")
232
+	}
233
+
234
+	f := make([]unix.SockFilter, len(asm))
235
+	for i := range asm {
236
+		f[i] = unix.SockFilter{
237
+			Code: asm[i].Op,
238
+			Jt:   asm[i].Jt,
239
+			Jf:   asm[i].Jf,
240
+			K:    asm[i].K,
241
+		}
242
+	}
243
+	return unix.SetsockoptSockFprog(sd, syscall.SOL_SOCKET, syscall.SO_ATTACH_FILTER, &unix.SockFprog{
244
+		Len:    uint16(len(f)),
245
+		Filter: &f[0],
246
+	})
247
+}
... ...
@@ -1,12 +1,18 @@
1 1
 package portallocator
2 2
 
3 3
 import (
4
+	"fmt"
4 5
 	"io"
5 6
 	"net"
6 7
 	"net/netip"
7 8
 	"os"
9
+	"os/exec"
10
+	"strconv"
11
+	"strings"
12
+	"sync/atomic"
8 13
 	"syscall"
9 14
 	"testing"
15
+	"time"
10 16
 
11 17
 	"github.com/ishidawataru/sctp"
12 18
 	"github.com/moby/moby/v2/daemon/libnetwork/netutils"
... ...
@@ -228,3 +234,159 @@ func TestOnlyOneSocketBindsUDPPort(t *testing.T) {
228 228
 	assert.ErrorContains(t, err, "failed to bind host port")
229 229
 	assert.Equal(t, len(socks), 0)
230 230
 }
231
+
232
+// TestSocketBacklogEqualsSomaxconn verifies that the listen syscall made for
233
+// TCP / SCTP sockets has a backlog size equal to somaxconn.
234
+func TestSocketBacklogEqualsSomaxconn(t *testing.T) {
235
+	// Retrieve and parse sysctl net.core.somaxconn
236
+	somaxconnSysctl, err := os.ReadFile("/proc/sys/net/core/somaxconn")
237
+	assert.NilError(t, err)
238
+	somaxconn, err := strconv.Atoi(strings.TrimSpace(string(somaxconnSysctl)))
239
+	assert.NilError(t, err)
240
+
241
+	// UDP isn't included in the list of protos to test because it doesn't have a backlog, and the ss Send-Q column
242
+	// reports memory allocation instead of the socket's max backlog size (unlike TCP and SCTP).
243
+	//
244
+	// This is where the kernel writes the max backlog size into the sk struct: https://elixir.bootlin.com/linux/v6.16/source/net/ipv4/af_inet.c#L199
245
+	//
246
+	// And here's where the kernel writes the 'idiag_wqueue' field used by ss:
247
+	//
248
+	// - For TCP: https://elixir.bootlin.com/linux/v6.16/source/net/ipv4/tcp_diag.c#L25
249
+	// - For UDP: https://elixir.bootlin.com/linux/v6.16/source/net/ipv4/udp_diag.c#L163
250
+	// - For SCTP: https://elixir.bootlin.com/linux/v6.16/source/net/sctp/diag.c#L414
251
+	for _, proto := range []types.Protocol{
252
+		types.TCP,
253
+		types.SCTP,
254
+	} {
255
+		t.Run(proto.String(), func(t *testing.T) {
256
+			// Allocate an ephemeral port using the OSAllocator.
257
+			alloc := NewOSAllocator()
258
+			port, socks, err := alloc.RequestPortsInRange([]net.IP{net.IPv4zero}, proto, 0, 0)
259
+			assert.NilError(t, err)
260
+			defer closeSocks(t, socks)
261
+
262
+			// 'ss' output looks like that:
263
+			//
264
+			//    Netid      State       Recv-Q      Send-Q           Local Address:Port            Peer Address:Port      Process
265
+			//    tcp        LISTEN      0           4096                   0.0.0.0:32768                0.0.0.0:*
266
+			//
267
+			// The max backlog size ('idiag_wqueue' field of 'struct inet_diag_msg' in the kernel) is the 4th field in
268
+			// the output.
269
+			out, err := exec.Command("ss", "-Stl", "sport", "=", fmt.Sprintf("inet:%d", port)).Output()
270
+			assert.NilError(t, err)
271
+
272
+			t.Logf("ss output:\n" + string(out))
273
+
274
+			lines := strings.Split(string(out), "\n")
275
+			assert.Assert(t, len(lines) >= 2)
276
+
277
+			fields := strings.Fields(lines[1])
278
+			assert.Equal(t, len(fields), 6)
279
+
280
+			backlog, err := strconv.Atoi(fields[3])
281
+			assert.NilError(t, err)
282
+
283
+			assert.Equal(t, fields[4], "0.0.0.0:"+strconv.Itoa(port))
284
+			assert.Equal(t, backlog, somaxconn, "socket backlog should be equal to net.core.somaxconn")
285
+		})
286
+	}
287
+}
288
+
289
+// TestPacketsAreDroppedUntilDetachSocketFilter tests that SYN packets are
290
+// dropped until DetachSocketFilter is called on the socket.
291
+func TestPacketsAreDroppedUntilDetachSocketFilter(t *testing.T) {
292
+	const port = 61100
293
+	addr := net.ParseIP("127.0.0.1")
294
+
295
+	var detached atomic.Bool
296
+	dialCh, readCh := make(chan error), make(chan error)
297
+
298
+	alloc := NewOSAllocator()
299
+	_, socks, err := alloc.RequestPortsInRange([]net.IP{addr}, types.TCP, port, port)
300
+	assert.NilError(t, err)
301
+	assert.Check(t, len(socks) > 0)
302
+
303
+	// Start a goroutine that attempts to connect to a listening socket. It'll send SYN packets until
304
+	// DetachSocketFilter is called. If no filter is attached, the connection will succeed immediately, and it'll send
305
+	// a payload of 0x0 (or the call to DetachSocketFilter will fail with an error). When the filter is detached, it'll
306
+	// send a payload of 0x1, which will be read by the other goroutine.
307
+	go func() {
308
+		defer close(dialCh)
309
+
310
+		c, err := net.Dial("tcp", net.JoinHostPort(addr.String(), strconv.Itoa(port)))
311
+		if err != nil {
312
+			dialCh <- fmt.Errorf("net.Dial: %w", err)
313
+			return
314
+		}
315
+		defer c.Close()
316
+
317
+		payload := []byte{0x0}
318
+		if detached.Load() {
319
+			payload = []byte{0x1}
320
+		}
321
+
322
+		n, err := c.Write(payload)
323
+		if err != nil {
324
+			dialCh <- fmt.Errorf("c.Write: %w", err)
325
+			return
326
+		}
327
+		if n != len(payload) {
328
+			dialCh <- fmt.Errorf("expected to write %d bytes, but wrote %d", len(payload), n)
329
+		}
330
+	}()
331
+
332
+	// Start a goroutine that accepts a connection on the listening socket created by RequestPortsInRange, and reads
333
+	// the payload sent by the 1st goroutine. It should not receive any new connection until DetachSocketFilter is
334
+	// called on the socket.
335
+	go func() {
336
+		defer close(readCh)
337
+
338
+		// net.FileListener dup's the fd, so DetachSocketFilter will have no effect. Use raw syscalls instead.
339
+		sd := int(socks[0].Fd())
340
+
341
+		var err error
342
+		connfd, _, err := syscall.Accept(sd)
343
+		if err != nil {
344
+			readCh <- fmt.Errorf("syscall.Accept: %w", err)
345
+			return
346
+		}
347
+
348
+		payload := make([]byte, 1)
349
+		n, err := syscall.Read(connfd, payload)
350
+		if err != nil {
351
+			readCh <- fmt.Errorf("c.Read: %w", err)
352
+			return
353
+		}
354
+		if n != 1 {
355
+			readCh <- fmt.Errorf("expected to read 1 byte, but read %d", n)
356
+			return
357
+		}
358
+
359
+		if payload[0] != 0x1 {
360
+			readCh <- fmt.Errorf("expected payload 0x1, but got %x", payload[0])
361
+		}
362
+	}()
363
+
364
+	// Sleep for a bit to make sure that both goroutines were scheduled.
365
+	time.Sleep(500 * time.Millisecond)
366
+
367
+	detached.Store(true)
368
+	err = DetachSocketFilter(socks[0])
369
+	assert.NilError(t, err)
370
+
371
+	var dialStopped, readStopped bool
372
+	for {
373
+		if dialStopped && readStopped {
374
+			return
375
+		}
376
+
377
+		select {
378
+		case err, ok := <-dialCh:
379
+			dialStopped = !ok
380
+			assert.NilError(t, err)
381
+		case err, ok := <-readCh:
382
+			readStopped = !ok
383
+			assert.NilError(t, err)
384
+		}
385
+	}
386
+}
... ...
@@ -8,7 +8,6 @@ import (
8 8
 	"net/netip"
9 9
 	"os"
10 10
 	"strconv"
11
-	"syscall"
12 11
 
13 12
 	"github.com/containerd/log"
14 13
 	"github.com/moby/moby/v2/daemon/libnetwork/internal/rlkclient"
... ...
@@ -105,6 +104,9 @@ func (pm PortMapper) MapPorts(ctx context.Context, cfg []portmapperapi.PortBindi
105 105
 			if bindings[i].BoundSocket == nil || bindings[i].RootlesskitUnsupported || bindings[i].StopProxy != nil {
106 106
 				continue
107 107
 			}
108
+			if err := portallocator.DetachSocketFilter(bindings[i].BoundSocket); err != nil {
109
+				return nil, fmt.Errorf("failed to detach socket filter for port mapping %s: %w", bindings[i].PortBinding, err)
110
+			}
108 111
 			var err error
109 112
 			bindings[i].StopProxy, err = pm.startProxy(
110 113
 				bindings[i].ChildPortBinding(), bindings[i].BoundSocket,
... ...
@@ -226,17 +228,6 @@ func (pm PortMapper) attemptBindHostPorts(
226 226
 	if err := fwn.AddPorts(ctx, mergeChildHostIPs(res)); err != nil {
227 227
 		return nil, err
228 228
 	}
229
-	// Now the firewall rules are set up, it's safe to listen on the socket. (Listening
230
-	// earlier could result in dropped connections if the proxy becomes unreachable due
231
-	// to NAT rules sending packets directly to the container.)
232
-	//
233
-	// If not starting the proxy, nothing will ever accept a connection on the
234
-	// socket. Listen here anyway because SO_REUSEADDR is set, so bind() won't notice
235
-	// the problem if a port's bound to both INADDR_ANY and a specific address. (Also
236
-	// so the binding shows up in "netstat -at".)
237
-	if err := listenBoundPorts(res, pm.enableProxy); err != nil {
238
-		return nil, err
239
-	}
240 229
 	return res, nil
241 230
 }
242 231
 
... ...
@@ -297,29 +288,3 @@ func configPortDriver(ctx context.Context, pbs []portmapperapi.PortBinding, pdc
297 297
 	}
298 298
 	return nil
299 299
 }
300
-
301
-func listenBoundPorts(pbs []portmapperapi.PortBinding, proxyEnabled bool) error {
302
-	for i := range pbs {
303
-		if pbs[i].BoundSocket == nil || pbs[i].RootlesskitUnsupported || pbs[i].Proto == types.UDP {
304
-			continue
305
-		}
306
-		rc, err := pbs[i].BoundSocket.SyscallConn()
307
-		if err != nil {
308
-			return fmt.Errorf("raw conn not available on %d socket: %w", pbs[i].Proto, err)
309
-		}
310
-		if errC := rc.Control(func(fd uintptr) {
311
-			somaxconn := 0
312
-			// SCTP sockets do not support somaxconn=0
313
-			if proxyEnabled || pbs[i].Proto == types.SCTP {
314
-				somaxconn = -1 // silently capped to "/proc/sys/net/core/somaxconn"
315
-			}
316
-			err = syscall.Listen(int(fd), somaxconn)
317
-		}); errC != nil {
318
-			return fmt.Errorf("failed to Control %s socket: %w", pbs[i].Proto, err)
319
-		}
320
-		if err != nil {
321
-			return fmt.Errorf("failed to listen on %s socket: %w", pbs[i].Proto, err)
322
-		}
323
-	}
324
-	return nil
325
-}