Browse code

Set up bridge-specific iptables rules in the bridge driver

Use the bridge driver's iptables types to set up portmapping related
iptables rules - instead of using iptables.Forward, which is bridge
specific code in the iptables package.

Remove iptables.Forward() and its unit test; the bridge driver's
version is covered by TestAddPortMappings.

Remove hairpinMode from iptables.ChainInfo. hairpinMode relates to
bridge-driver-specific behaviour that is now implemented in the bridge driver.

Signed-off-by: Rob Murray <rob.murray@docker.com>

Rob Murray authored on 2024/05/28 23:57:56
Showing 7 changed files
... ...
@@ -329,6 +329,21 @@ func (n *bridgeNetwork) registerIptCleanFunc(clean iptableCleanFunc) {
329 329
 	n.iptCleanFuncs = append(n.iptCleanFuncs, clean)
330 330
 }
331 331
 
332
+func (n *bridgeNetwork) iptablesEnabled(version iptables.IPVersion) (bool, error) {
333
+	n.Lock()
334
+	defer n.Unlock()
335
+	if n.driver == nil {
336
+		return false, types.InvalidParameterErrorf("no driver found")
337
+	}
338
+
339
+	n.driver.Lock()
340
+	defer n.driver.Unlock()
341
+	if version == iptables.IPv6 {
342
+		return n.driver.config.EnableIP6Tables, nil
343
+	}
344
+	return n.driver.config.EnableIPTables, nil
345
+}
346
+
332 347
 func (n *bridgeNetwork) getDriverChains(version iptables.IPVersion) (*iptables.ChainInfo, *iptables.ChainInfo, *iptables.ChainInfo, *iptables.ChainInfo, error) {
333 348
 	n.Lock()
334 349
 	defer n.Unlock()
... ...
@@ -5,6 +5,7 @@ import (
5 5
 	"errors"
6 6
 	"fmt"
7 7
 	"net"
8
+	"strconv"
8 9
 
9 10
 	"github.com/containerd/log"
10 11
 	"github.com/docker/docker/libnetwork/iptables"
... ...
@@ -291,24 +292,96 @@ func (n *bridgeNetwork) setPerPortIptables(b portBinding, enable bool) error {
291 291
 		v = iptables.IPv6
292 292
 	}
293 293
 
294
-	natChain, _, _, _, err := n.getDriverChains(v)
295
-	if err != nil || natChain == nil {
294
+	if enabled, err := n.iptablesEnabled(v); err != nil || !enabled {
296 295
 		// Nothing to do, iptables/ip6tables is not enabled.
297 296
 		return nil
298 297
 	}
299
-	action := iptables.Delete
300
-	if enable {
301
-		action = iptables.Insert
302
-	}
303
-	return natChain.Forward(
304
-		action,
305
-		b.HostIP,
306
-		int(b.HostPort),
307
-		b.Proto.String(),
308
-		b.IP.String(),
309
-		int(b.Port),
310
-		n.getNetworkBridgeName(),
311
-	)
298
+
299
+	bridgeName := n.getNetworkBridgeName()
300
+	proxyPath := n.userlandProxyPath()
301
+	if err := setPerPortNAT(b, v, proxyPath, bridgeName, enable); err != nil {
302
+		return err
303
+	}
304
+	if err := setPerPortForwarding(b, v, bridgeName, enable); err != nil {
305
+		return err
306
+	}
307
+	return nil
308
+}
309
+
310
+func setPerPortNAT(b portBinding, ipv iptables.IPVersion, proxyPath string, bridgeName string, enable bool) error {
311
+	// iptables interprets "0.0.0.0" as "0.0.0.0/32", whereas we
312
+	// want "0.0.0.0/0". "0/0" is correctly interpreted as "any
313
+	// value" by both iptables and ip6tables.
314
+	hostIP := "0/0"
315
+	if !b.HostIP.IsUnspecified() {
316
+		hostIP = b.HostIP.String()
317
+	}
318
+	args := []string{
319
+		"-p", b.Proto.String(),
320
+		"-d", hostIP,
321
+		"--dport", strconv.Itoa(int(b.HostPort)),
322
+		"-j", "DNAT",
323
+		"--to-destination", net.JoinHostPort(b.IP.String(), strconv.Itoa(int(b.Port))),
324
+	}
325
+	hairpinMode := proxyPath == ""
326
+	if !hairpinMode {
327
+		args = append(args, "!", "-i", bridgeName)
328
+	}
329
+	rule := iptRule{ipv: ipv, table: iptables.Nat, chain: DockerChain, args: args}
330
+	if err := programChainRule(rule, "DNAT", enable); err != nil {
331
+		return err
332
+	}
333
+
334
+	args = []string{
335
+		"-p", b.Proto.String(),
336
+		"-s", b.IP.String(),
337
+		"-d", b.IP.String(),
338
+		"--dport", strconv.Itoa(int(b.Port)),
339
+		"-j", "MASQUERADE",
340
+	}
341
+	rule = iptRule{ipv: ipv, table: iptables.Nat, chain: "POSTROUTING", args: args}
342
+	if err := programChainRule(rule, "MASQUERADE", enable); err != nil {
343
+		return err
344
+	}
345
+
346
+	return nil
347
+}
348
+
349
+func setPerPortForwarding(b portBinding, ipv iptables.IPVersion, bridgeName string, enable bool) error {
350
+	args := []string{
351
+		"!", "-i", bridgeName,
352
+		"-o", bridgeName,
353
+		"-p", b.Proto.String(),
354
+		"-d", b.IP.String(),
355
+		"--dport", strconv.Itoa(int(b.Port)),
356
+		"-j", "ACCEPT",
357
+	}
358
+	rule := iptRule{ipv: ipv, table: iptables.Filter, chain: DockerChain, args: args}
359
+	if err := programChainRule(rule, "MASQUERADE", enable); err != nil {
360
+		return err
361
+	}
362
+
363
+	if b.Proto == types.SCTP {
364
+		// Linux kernel v4.9 and below enables NETIF_F_SCTP_CRC for veth by
365
+		// the following commit.
366
+		// This introduces a problem when combined with a physical NIC without
367
+		// NETIF_F_SCTP_CRC. As for a workaround, here we add an iptables entry
368
+		// to fill the checksum.
369
+		//
370
+		// https://github.com/torvalds/linux/commit/c80fafbbb59ef9924962f83aac85531039395b18
371
+		args = []string{
372
+			"-p", b.Proto.String(),
373
+			"--sport", strconv.Itoa(int(b.Port)),
374
+			"-j", "CHECKSUM",
375
+			"--checksum-fill",
376
+		}
377
+		rule := iptRule{ipv: ipv, table: iptables.Mangle, chain: "POSTROUTING", args: args}
378
+		if err := programChainRule(rule, "MASQUERADE", enable); err != nil {
379
+			return err
380
+		}
381
+	}
382
+
383
+	return nil
312 384
 }
313 385
 
314 386
 func (n *bridgeNetwork) reapplyPerPortIptables4() {
... ...
@@ -41,11 +41,9 @@ func setupIPChains(config configuration, version iptables.IPVersion) (natChain *
41 41
 		return nil, nil, nil, nil, errors.New("cannot create new chains, ip6tables is disabled")
42 42
 	}
43 43
 
44
-	hairpinMode := !config.EnableUserlandProxy
45
-
46 44
 	iptable := iptables.GetIptable(version)
47 45
 
48
-	natChain, err := iptable.NewChain(DockerChain, iptables.Nat, hairpinMode)
46
+	natChain, err := iptable.NewChain(DockerChain, iptables.Nat)
49 47
 	if err != nil {
50 48
 		return nil, nil, nil, nil, fmt.Errorf("failed to create NAT chain %s: %v", DockerChain, err)
51 49
 	}
... ...
@@ -57,7 +55,7 @@ func setupIPChains(config configuration, version iptables.IPVersion) (natChain *
57 57
 		}
58 58
 	}()
59 59
 
60
-	filterChain, err = iptable.NewChain(DockerChain, iptables.Filter, false)
60
+	filterChain, err = iptable.NewChain(DockerChain, iptables.Filter)
61 61
 	if err != nil {
62 62
 		return nil, nil, nil, nil, fmt.Errorf("failed to create FILTER chain %s: %v", DockerChain, err)
63 63
 	}
... ...
@@ -69,7 +67,7 @@ func setupIPChains(config configuration, version iptables.IPVersion) (natChain *
69 69
 		}
70 70
 	}()
71 71
 
72
-	isolationChain1, err = iptable.NewChain(IsolationChain1, iptables.Filter, false)
72
+	isolationChain1, err = iptable.NewChain(IsolationChain1, iptables.Filter)
73 73
 	if err != nil {
74 74
 		return nil, nil, nil, nil, fmt.Errorf("failed to create FILTER isolation chain: %v", err)
75 75
 	}
... ...
@@ -81,7 +79,7 @@ func setupIPChains(config configuration, version iptables.IPVersion) (natChain *
81 81
 		}
82 82
 	}()
83 83
 
84
-	isolationChain2, err = iptable.NewChain(IsolationChain2, iptables.Filter, false)
84
+	isolationChain2, err = iptable.NewChain(IsolationChain2, iptables.Filter)
85 85
 	if err != nil {
86 86
 		return nil, nil, nil, nil, fmt.Errorf("failed to create FILTER isolation chain: %v", err)
87 87
 	}
... ...
@@ -41,7 +41,7 @@ func arrangeUserFilterRule() {
41 41
 // that are beyond the daemon's control.
42 42
 func setupUserChain(ipVersion iptables.IPVersion) error {
43 43
 	ipt := iptables.GetIptable(ipVersion)
44
-	if _, err := ipt.NewChain(userChain, iptables.Filter, false); err != nil {
44
+	if _, err := ipt.NewChain(userChain, iptables.Filter); err != nil {
45 45
 		return fmt.Errorf("failed to create %s %v chain: %v", userChain, ipVersion, err)
46 46
 	}
47 47
 	if err := ipt.AddReturnRule(userChain); err != nil {
... ...
@@ -34,7 +34,7 @@ func TestFirewalldInit(t *testing.T) {
34 34
 
35 35
 func TestReloaded(t *testing.T) {
36 36
 	iptable := GetIptable(IPv4)
37
-	fwdChain, err := iptable.NewChain("FWD", Filter, false)
37
+	fwdChain, err := iptable.NewChain("FWD", Filter)
38 38
 	if err != nil {
39 39
 		t.Fatal(err)
40 40
 	}
... ...
@@ -78,10 +78,9 @@ type IPTable struct {
78 78
 
79 79
 // ChainInfo defines the iptables chain.
80 80
 type ChainInfo struct {
81
-	Name        string
82
-	Table       Table
83
-	HairpinMode bool
84
-	IPVersion   IPVersion
81
+	Name      string
82
+	Table     Table
83
+	IPVersion IPVersion
85 84
 }
86 85
 
87 86
 // ChainError is returned to represent errors during ip table operation.
... ...
@@ -173,7 +172,7 @@ func GetIptable(version IPVersion) *IPTable {
173 173
 }
174 174
 
175 175
 // NewChain adds a new chain to ip table.
176
-func (iptable IPTable) NewChain(name string, table Table, hairpinMode bool) (*ChainInfo, error) {
176
+func (iptable IPTable) NewChain(name string, table Table) (*ChainInfo, error) {
177 177
 	if name == "" {
178 178
 		return nil, fmt.Errorf("could not create chain: chain name is empty")
179 179
 	}
... ...
@@ -189,10 +188,9 @@ func (iptable IPTable) NewChain(name string, table Table, hairpinMode bool) (*Ch
189 189
 		}
190 190
 	}
191 191
 	return &ChainInfo{
192
-		Name:        name,
193
-		Table:       table,
194
-		HairpinMode: hairpinMode,
195
-		IPVersion:   iptable.ipVersion,
192
+		Name:      name,
193
+		Table:     table,
194
+		IPVersion: iptable.ipVersion,
196 195
 	}, nil
197 196
 }
198 197
 
... ...
@@ -310,78 +308,6 @@ func (iptable IPTable) RemoveExistingChain(name string, table Table) error {
310 310
 	return c.Remove()
311 311
 }
312 312
 
313
-// Forward adds forwarding rule to 'filter' table and corresponding nat rule to 'nat' table.
314
-func (c *ChainInfo) Forward(action Action, ip net.IP, port int, proto, destAddr string, destPort int, bridgeName string) error {
315
-	iptable := GetIptable(c.IPVersion)
316
-	daddr := ip.String()
317
-	if ip.IsUnspecified() {
318
-		// iptables interprets "0.0.0.0" as "0.0.0.0/32", whereas we
319
-		// want "0.0.0.0/0". "0/0" is correctly interpreted as "any
320
-		// value" by both iptables and ip6tables.
321
-		daddr = "0/0"
322
-	}
323
-
324
-	args := []string{
325
-		"-p", proto,
326
-		"-d", daddr,
327
-		"--dport", strconv.Itoa(port),
328
-		"-j", "DNAT",
329
-		"--to-destination", net.JoinHostPort(destAddr, strconv.Itoa(destPort)),
330
-	}
331
-
332
-	if !c.HairpinMode {
333
-		args = append(args, "!", "-i", bridgeName)
334
-	}
335
-	if err := iptable.ProgramRule(Nat, c.Name, action, args); err != nil {
336
-		return err
337
-	}
338
-
339
-	args = []string{
340
-		"!", "-i", bridgeName,
341
-		"-o", bridgeName,
342
-		"-p", proto,
343
-		"-d", destAddr,
344
-		"--dport", strconv.Itoa(destPort),
345
-		"-j", "ACCEPT",
346
-	}
347
-	if err := iptable.ProgramRule(Filter, c.Name, action, args); err != nil {
348
-		return err
349
-	}
350
-
351
-	args = []string{
352
-		"-p", proto,
353
-		"-s", destAddr,
354
-		"-d", destAddr,
355
-		"--dport", strconv.Itoa(destPort),
356
-		"-j", "MASQUERADE",
357
-	}
358
-
359
-	if err := iptable.ProgramRule(Nat, "POSTROUTING", action, args); err != nil {
360
-		return err
361
-	}
362
-
363
-	if proto == "sctp" {
364
-		// Linux kernel v4.9 and below enables NETIF_F_SCTP_CRC for veth by
365
-		// the following commit.
366
-		// This introduces a problem when combined with a physical NIC without
367
-		// NETIF_F_SCTP_CRC. As for a workaround, here we add an iptables entry
368
-		// to fill the checksum.
369
-		//
370
-		// https://github.com/torvalds/linux/commit/c80fafbbb59ef9924962f83aac85531039395b18
371
-		args = []string{
372
-			"-p", proto,
373
-			"--sport", strconv.Itoa(destPort),
374
-			"-j", "CHECKSUM",
375
-			"--checksum-fill",
376
-		}
377
-		if err := iptable.ProgramRule(Mangle, "POSTROUTING", action, args); err != nil {
378
-			return err
379
-		}
380
-	}
381
-
382
-	return nil
383
-}
384
-
385 313
 // Link adds reciprocal ACCEPT rule for two supplied IP addresses.
386 314
 // Traffic is allowed from ip1 to ip2 and vice-versa
387 315
 func (c *ChainInfo) Link(action Action, ip1, ip2 net.IP, port int, proto string, bridgeName string) error {
... ...
@@ -21,7 +21,7 @@ func createNewChain(t *testing.T) (*IPTable, *ChainInfo, *ChainInfo) {
21 21
 	t.Helper()
22 22
 	iptable := GetIptable(IPv4)
23 23
 
24
-	natChain, err := iptable.NewChain(chainName, Nat, false)
24
+	natChain, err := iptable.NewChain(chainName, Nat)
25 25
 	if err != nil {
26 26
 		t.Fatal(err)
27 27
 	}
... ...
@@ -30,7 +30,7 @@ func createNewChain(t *testing.T) (*IPTable, *ChainInfo, *ChainInfo) {
30 30
 		t.Fatal(err)
31 31
 	}
32 32
 
33
-	filterChain, err := iptable.NewChain(chainName, Filter, false)
33
+	filterChain, err := iptable.NewChain(chainName, Filter)
34 34
 	if err != nil {
35 35
 		t.Fatal(err)
36 36
 	}
... ...
@@ -46,59 +46,6 @@ func TestNewChain(t *testing.T) {
46 46
 	createNewChain(t)
47 47
 }
48 48
 
49
-func TestForward(t *testing.T) {
50
-	iptable, natChain, filterChain := createNewChain(t)
51
-
52
-	ip := net.ParseIP("192.168.1.1")
53
-	port := 1234
54
-	dstAddr := "172.17.0.1"
55
-	dstPort := 4321
56
-	proto := "tcp"
57
-
58
-	err := natChain.Forward(Insert, ip, port, proto, dstAddr, dstPort, bridgeName)
59
-	if err != nil {
60
-		t.Fatal(err)
61
-	}
62
-
63
-	dnatRule := []string{
64
-		"-d", ip.String(),
65
-		"-p", proto,
66
-		"--dport", strconv.Itoa(port),
67
-		"-j", "DNAT",
68
-		"--to-destination", dstAddr + ":" + strconv.Itoa(dstPort),
69
-		"!", "-i", bridgeName,
70
-	}
71
-
72
-	if !iptable.Exists(natChain.Table, natChain.Name, dnatRule...) {
73
-		t.Fatal("DNAT rule does not exist")
74
-	}
75
-
76
-	filterRule := []string{
77
-		"!", "-i", bridgeName,
78
-		"-o", bridgeName,
79
-		"-d", dstAddr,
80
-		"-p", proto,
81
-		"--dport", strconv.Itoa(dstPort),
82
-		"-j", "ACCEPT",
83
-	}
84
-
85
-	if !iptable.Exists(filterChain.Table, filterChain.Name, filterRule...) {
86
-		t.Fatal("filter rule does not exist")
87
-	}
88
-
89
-	masqRule := []string{
90
-		"-d", dstAddr,
91
-		"-s", dstAddr,
92
-		"-p", proto,
93
-		"--dport", strconv.Itoa(dstPort),
94
-		"-j", "MASQUERADE",
95
-	}
96
-
97
-	if !iptable.Exists(natChain.Table, "POSTROUTING", masqRule...) {
98
-		t.Fatal("MASQUERADE rule does not exist")
99
-	}
100
-}
101
-
102 49
 func TestLink(t *testing.T) {
103 50
 	iptable, _, filterChain := createNewChain(t)
104 51
 	ip1 := net.ParseIP("192.168.1.1")
... ...
@@ -210,7 +157,7 @@ func RunConcurrencyTest(t *testing.T, allowXlock bool) {
210 210
 	group := new(errgroup.Group)
211 211
 	for i := 0; i < 10; i++ {
212 212
 		group.Go(func() error {
213
-			return natChain.Forward(Append, ip, port, proto, dstAddr, dstPort, "lo")
213
+			return addSomeRules(natChain, ip, port, proto, dstAddr, dstPort)
214 214
 		})
215 215
 	}
216 216
 	if err := group.Wait(); err != nil {
... ...
@@ -218,6 +165,50 @@ func RunConcurrencyTest(t *testing.T, allowXlock bool) {
218 218
 	}
219 219
 }
220 220
 
221
+// addSomeRules adds arbitrary iptable rules. RunConcurrencyTest previously used
222
+// iptables.Forward to create rules, that function has been removed. To preserve
223
+// the test, this function creates similar rules.
224
+func addSomeRules(c *ChainInfo, ip net.IP, port int, proto, destAddr string, destPort int) error {
225
+	iptable := GetIptable(c.IPVersion)
226
+	daddr := ip.String()
227
+
228
+	args := []string{
229
+		"-p", proto,
230
+		"-d", daddr,
231
+		"--dport", strconv.Itoa(port),
232
+		"-j", "DNAT",
233
+		"--to-destination", net.JoinHostPort(destAddr, strconv.Itoa(destPort)),
234
+	}
235
+	if err := iptable.ProgramRule(Nat, c.Name, Append, args); err != nil {
236
+		return err
237
+	}
238
+
239
+	args = []string{
240
+		"!", "-i", "lo",
241
+		"-o", "lo",
242
+		"-p", proto,
243
+		"-d", destAddr,
244
+		"--dport", strconv.Itoa(destPort),
245
+		"-j", "ACCEPT",
246
+	}
247
+	if err := iptable.ProgramRule(Filter, c.Name, Append, args); err != nil {
248
+		return err
249
+	}
250
+
251
+	args = []string{
252
+		"-p", proto,
253
+		"-s", destAddr,
254
+		"-d", destAddr,
255
+		"--dport", strconv.Itoa(destPort),
256
+		"-j", "MASQUERADE",
257
+	}
258
+	if err := iptable.ProgramRule(Nat, "POSTROUTING", Append, args); err != nil {
259
+		return err
260
+	}
261
+
262
+	return nil
263
+}
264
+
221 265
 func TestCleanup(t *testing.T) {
222 266
 	iptable, _, filterChain := createNewChain(t)
223 267
 
... ...
@@ -256,7 +247,7 @@ func TestExistsRaw(t *testing.T) {
256 256
 
257 257
 	iptable := GetIptable(IPv4)
258 258
 
259
-	_, err := iptable.NewChain(testChain1, Filter, false)
259
+	_, err := iptable.NewChain(testChain1, Filter)
260 260
 	if err != nil {
261 261
 		t.Fatal(err)
262 262
 	}
... ...
@@ -264,7 +255,7 @@ func TestExistsRaw(t *testing.T) {
264 264
 		iptable.RemoveExistingChain(testChain1, Filter)
265 265
 	}()
266 266
 
267
-	_, err = iptable.NewChain(testChain2, Filter, false)
267
+	_, err = iptable.NewChain(testChain2, Filter)
268 268
 	if err != nil {
269 269
 		t.Fatal(err)
270 270
 	}