Browse code

Merge pull request #45987 from thaJeztah/cleanup_iptables_the_sequel

libnetwork/iptables: some cleanups and refactoring: the sequel

Sebastiaan van Stijn authored on 2023/07/19 21:38:12
Showing 5 changed files
... ...
@@ -376,7 +376,7 @@ func setINC(version iptables.IPVersion, iface string, enable bool) error {
376 376
 const oldIsolationChain = "DOCKER-ISOLATION"
377 377
 
378 378
 func removeIPChains(version iptables.IPVersion) {
379
-	ipt := iptables.IPTable{Version: version}
379
+	ipt := iptables.GetIptable(version)
380 380
 
381 381
 	// Remove obsolete rules from default chains
382 382
 	ipt.ProgramRule(iptables.Filter, "FORWARD", iptables.Delete, []string{"-j", oldIsolationChain})
... ...
@@ -110,6 +110,6 @@ func resetIptables(t *testing.T) {
110 110
 
111 111
 		_, err := iptable.Raw("-F", fwdChainName)
112 112
 		assert.Check(t, err)
113
-		_ = iptable.RemoveExistingChain(usrChainName, "")
113
+		_ = iptable.RemoveExistingChain(usrChainName, iptables.Filter)
114 114
 	}
115 115
 }
... ...
@@ -65,8 +65,8 @@ var (
65 65
 	onReloaded       []*func() // callbacks when Firewalld has been reloaded
66 66
 )
67 67
 
68
-// FirewalldInit initializes firewalld management code.
69
-func FirewalldInit() error {
68
+// firewalldInit initializes firewalld management code.
69
+func firewalldInit() error {
70 70
 	var err error
71 71
 
72 72
 	if connection, err = newConnection(); err != nil {
... ...
@@ -13,7 +13,7 @@ func TestFirewalldInit(t *testing.T) {
13 13
 	if !checkRunning() {
14 14
 		t.Skip("firewalld is not running")
15 15
 	}
16
-	if err := FirewalldInit(); err != nil {
16
+	if err := firewalldInit(); err != nil {
17 17
 		t.Fatal(err)
18 18
 	}
19 19
 }
... ...
@@ -21,15 +21,6 @@ import (
21 21
 // Action signifies the iptable action.
22 22
 type Action string
23 23
 
24
-// Policy is the default iptable policies
25
-type Policy string
26
-
27
-// Table refers to Nat, Filter or Mangle.
28
-type Table string
29
-
30
-// IPVersion refers to IP version, v4 or v6
31
-type IPVersion string
32
-
33 24
 const (
34 25
 	// Append appends the rule at the end of the chain.
35 26
 	Append Action = "-A"
... ...
@@ -37,19 +28,37 @@ const (
37 37
 	Delete Action = "-D"
38 38
 	// Insert inserts the rule at the top of the chain.
39 39
 	Insert Action = "-I"
40
+)
41
+
42
+// Policy is the default iptable policies
43
+type Policy string
44
+
45
+const (
46
+	// Drop is the default iptables DROP policy.
47
+	Drop Policy = "DROP"
48
+	// Accept is the default iptables ACCEPT policy.
49
+	Accept Policy = "ACCEPT"
50
+)
51
+
52
+// Table refers to Nat, Filter or Mangle.
53
+type Table string
54
+
55
+const (
40 56
 	// Nat table is used for nat translation rules.
41 57
 	Nat Table = "nat"
42 58
 	// Filter table is used for filter rules.
43 59
 	Filter Table = "filter"
44 60
 	// Mangle table is used for mangling the packet.
45 61
 	Mangle Table = "mangle"
46
-	// Drop is the default iptables DROP policy
47
-	Drop Policy = "DROP"
48
-	// Accept is the default iptables ACCEPT policy
49
-	Accept Policy = "ACCEPT"
50
-	// IPv4 is version 4
62
+)
63
+
64
+// IPVersion refers to IP version, v4 or v6
65
+type IPVersion string
66
+
67
+const (
68
+	// IPv4 is version 4.
51 69
 	IPv4 IPVersion = "IPV4"
52
-	// IPv6 is version 6
70
+	// IPv6 is version 6.
53 71
 	IPv6 IPVersion = "IPV6"
54 72
 )
55 73
 
... ...
@@ -57,15 +66,14 @@ var (
57 57
 	iptablesPath  string
58 58
 	ip6tablesPath string
59 59
 	supportsXlock = false
60
-	xLockWaitMsg  = "Another app is currently holding the xtables lock"
61 60
 	// used to lock iptables commands if xtables lock is not supported
62 61
 	bestEffortLock sync.Mutex
63 62
 	initOnce       sync.Once
64 63
 )
65 64
 
66
-// IPTable defines struct with IPVersion
65
+// IPTable defines struct with [IPVersion].
67 66
 type IPTable struct {
68
-	Version IPVersion
67
+	ipVersion IPVersion
69 68
 }
70 69
 
71 70
 // ChainInfo defines the iptables chain.
... ...
@@ -86,6 +94,19 @@ func (e ChainError) Error() string {
86 86
 	return fmt.Sprintf("error iptables %s: %s", e.Chain, string(e.Output))
87 87
 }
88 88
 
89
+// loopbackAddress returns the loopback address for the given IP version.
90
+func loopbackAddress(version IPVersion) string {
91
+	switch version {
92
+	case IPv4, "":
93
+		// IPv4 (default for backward-compatibility)
94
+		return "127.0.0.0/8"
95
+	case IPv6:
96
+		return "::1/128"
97
+	default:
98
+		panic("unknown IP version: " + version)
99
+	}
100
+}
101
+
89 102
 func detectIptables() {
90 103
 	path, err := exec.LookPath("iptables")
91 104
 	if err != nil {
... ...
@@ -117,7 +138,7 @@ func initFirewalld() {
117 117
 		log.G(context.TODO()).Info("skipping firewalld management for rootless mode")
118 118
 		return
119 119
 	}
120
-	if err := FirewalldInit(); err != nil {
120
+	if err := firewalldInit(); err != nil {
121 121
 		log.G(context.TODO()).WithError(err).Debugf("unable to initialize firewalld; using raw iptables instead")
122 122
 	}
123 123
 }
... ...
@@ -136,15 +157,28 @@ func initCheck() error {
136 136
 	return nil
137 137
 }
138 138
 
139
-// GetIptable returns an instance of IPTable with specified version
139
+// GetIptable returns an instance of IPTable with specified version ([IPv4]
140
+// or [IPv6]). It panics if an invalid [IPVersion] is provided.
140 141
 func GetIptable(version IPVersion) *IPTable {
141
-	return &IPTable{Version: version}
142
+	switch version {
143
+	case IPv4, IPv6:
144
+		// valid version
145
+	case "":
146
+		// default is IPv4 for backward-compatibility
147
+		version = IPv4
148
+	default:
149
+		panic("unknown IP version: " + version)
150
+	}
151
+	return &IPTable{ipVersion: version}
142 152
 }
143 153
 
144 154
 // NewChain adds a new chain to ip table.
145 155
 func (iptable IPTable) NewChain(name string, table Table, hairpinMode bool) (*ChainInfo, error) {
156
+	if name == "" {
157
+		return nil, fmt.Errorf("could not create chain: chain name is empty")
158
+	}
146 159
 	if table == "" {
147
-		table = Filter
160
+		return nil, fmt.Errorf("could not create chain %s: invalid table name: table name is empty", name)
148 161
 	}
149 162
 	// Add chain if it doesn't exist
150 163
 	if _, err := iptable.Raw("-t", string(table), "-n", "-L", name); err != nil {
... ...
@@ -158,18 +192,10 @@ func (iptable IPTable) NewChain(name string, table Table, hairpinMode bool) (*Ch
158 158
 		Name:        name,
159 159
 		Table:       table,
160 160
 		HairpinMode: hairpinMode,
161
-		IPVersion:   iptable.Version,
161
+		IPVersion:   iptable.ipVersion,
162 162
 	}, nil
163 163
 }
164 164
 
165
-// LoopbackByVersion returns loopback address by version
166
-func (iptable IPTable) LoopbackByVersion() string {
167
-	if iptable.Version == IPv6 {
168
-		return "::1/128"
169
-	}
170
-	return "127.0.0.0/8"
171
-}
172
-
173 165
 // ProgramChain is used to add rules to a chain
174 166
 func (iptable IPTable) ProgramChain(c *ChainInfo, bridgeName string, hairpinMode, enable bool) error {
175 167
 	if c.Name == "" {
... ...
@@ -211,7 +237,7 @@ func (iptable IPTable) ProgramChain(c *ChainInfo, bridgeName string, hairpinMode
211 211
 			"-j", c.Name,
212 212
 		}
213 213
 		if !hairpinMode {
214
-			output = append(output, "!", "--dst", iptable.LoopbackByVersion())
214
+			output = append(output, "!", "--dst", loopbackAddress(iptable.ipVersion))
215 215
 		}
216 216
 		if !iptable.Exists(Nat, "OUTPUT", output...) && enable {
217 217
 			if err := c.Output(Append, output...); err != nil {
... ...
@@ -272,13 +298,16 @@ func (iptable IPTable) ProgramChain(c *ChainInfo, bridgeName string, hairpinMode
272 272
 
273 273
 // RemoveExistingChain removes existing chain from the table.
274 274
 func (iptable IPTable) RemoveExistingChain(name string, table Table) error {
275
+	if name == "" {
276
+		return fmt.Errorf("could not remove chain: chain name is empty")
277
+	}
275 278
 	if table == "" {
276
-		table = Filter
279
+		return fmt.Errorf("could not remove chain %s: invalid table name: table name is empty", name)
277 280
 	}
278 281
 	c := &ChainInfo{
279 282
 		Name:      name,
280 283
 		Table:     table,
281
-		IPVersion: iptable.Version,
284
+		IPVersion: iptable.ipVersion,
282 285
 	}
283 286
 	return c.Remove()
284 287
 }
... ...
@@ -419,15 +448,15 @@ func (c *ChainInfo) Output(action Action, args ...string) error {
419 419
 
420 420
 // Remove removes the chain.
421 421
 func (c *ChainInfo) Remove() error {
422
-	iptable := GetIptable(c.IPVersion)
423 422
 	// Ignore errors - This could mean the chains were never set up
424 423
 	if c.Table == Nat {
425 424
 		_ = c.Prerouting(Delete, "-m", "addrtype", "--dst-type", "LOCAL", "-j", c.Name)
426
-		_ = c.Output(Delete, "-m", "addrtype", "--dst-type", "LOCAL", "!", "--dst", iptable.LoopbackByVersion(), "-j", c.Name)
425
+		_ = c.Output(Delete, "-m", "addrtype", "--dst-type", "LOCAL", "!", "--dst", loopbackAddress(c.IPVersion), "-j", c.Name)
427 426
 		_ = c.Output(Delete, "-m", "addrtype", "--dst-type", "LOCAL", "-j", c.Name) // Created in versions <= 0.1.6
428 427
 		_ = c.Prerouting(Delete)
429 428
 		_ = c.Output(Delete)
430 429
 	}
430
+	iptable := GetIptable(c.IPVersion)
431 431
 	_, _ = iptable.Raw("-t", string(c.Table), "-F", c.Name)
432 432
 	_, _ = iptable.Raw("-t", string(c.Table), "-X", c.Name)
433 433
 	return nil
... ...
@@ -465,14 +494,17 @@ func (iptable IPTable) exists(native bool, table Table, chain string, rule ...st
465 465
 	return err == nil
466 466
 }
467 467
 
468
-// Maximum duration that an iptables operation can take
469
-// before flagging a warning.
470
-const opWarnTime = 2 * time.Second
468
+const (
469
+	// opWarnTime is the maximum duration that an iptables operation can take before flagging a warning.
470
+	opWarnTime = 2 * time.Second
471
+
472
+	// xLockWaitMsg is the iptables warning about xtables lock that can be suppressed.
473
+	xLockWaitMsg = "Another app is currently holding the xtables lock"
474
+)
471 475
 
472 476
 func filterOutput(start time.Time, output []byte, args ...string) []byte {
473
-	// Flag operations that have taken a long time to complete
474
-	opTime := time.Since(start)
475
-	if opTime > opWarnTime {
477
+	if opTime := time.Since(start); opTime > opWarnTime {
478
+		// Flag operations that have taken a long time to complete
476 479
 		log.G(context.TODO()).Warnf("xtables contention detected while running [%s]: Waited for %.2f seconds and received %q", strings.Join(args, " "), float64(opTime)/float64(time.Second), string(output))
477 480
 	}
478 481
 	// ignore iptables' message about xtables lock:
... ...
@@ -489,7 +521,7 @@ func (iptable IPTable) Raw(args ...string) ([]byte, error) {
489 489
 	if firewalldRunning {
490 490
 		// select correct IP version for firewalld
491 491
 		ipv := Iptables
492
-		if iptable.Version == IPv6 {
492
+		if iptable.ipVersion == IPv6 {
493 493
 			ipv = IP6Tables
494 494
 		}
495 495
 
... ...
@@ -506,16 +538,9 @@ func (iptable IPTable) raw(args ...string) ([]byte, error) {
506 506
 	if err := initCheck(); err != nil {
507 507
 		return nil, err
508 508
 	}
509
-	if supportsXlock {
510
-		args = append([]string{"--wait"}, args...)
511
-	} else {
512
-		bestEffortLock.Lock()
513
-		defer bestEffortLock.Unlock()
514
-	}
515
-
516 509
 	path := iptablesPath
517 510
 	commandName := "iptables"
518
-	if iptable.Version == IPv6 {
511
+	if iptable.ipVersion == IPv6 {
519 512
 		if ip6tablesPath == "" {
520 513
 			return nil, fmt.Errorf("ip6tables is missing")
521 514
 		}
... ...
@@ -523,6 +548,13 @@ func (iptable IPTable) raw(args ...string) ([]byte, error) {
523 523
 		commandName = "ip6tables"
524 524
 	}
525 525
 
526
+	if supportsXlock {
527
+		args = append([]string{"--wait"}, args...)
528
+	} else {
529
+		bestEffortLock.Lock()
530
+		defer bestEffortLock.Unlock()
531
+	}
532
+
526 533
 	log.G(context.TODO()).Debugf("%s, %v", path, args)
527 534
 
528 535
 	startTime := time.Now()
... ...
@@ -554,10 +586,8 @@ func (iptable IPTable) RawCombinedOutputNative(args ...string) error {
554 554
 
555 555
 // ExistChain checks if a chain exists
556 556
 func (iptable IPTable) ExistChain(chain string, table Table) bool {
557
-	if _, err := iptable.Raw("-t", string(table), "-nL", chain); err == nil {
558
-		return true
559
-	}
560
-	return false
557
+	_, err := iptable.Raw("-t", string(table), "-nL", chain)
558
+	return err == nil
561 559
 }
562 560
 
563 561
 // SetDefaultPolicy sets the passed default policy for the table/chain
... ...
@@ -573,28 +603,21 @@ func (iptable IPTable) AddReturnRule(chain string) error {
573 573
 	if iptable.Exists(Filter, chain, "-j", "RETURN") {
574 574
 		return nil
575 575
 	}
576
-
577
-	err := iptable.RawCombinedOutput("-A", chain, "-j", "RETURN")
578
-	if err != nil {
576
+	if err := iptable.RawCombinedOutput("-A", chain, "-j", "RETURN"); err != nil {
579 577
 		return fmt.Errorf("unable to add return rule in %s chain: %v", chain, err)
580 578
 	}
581
-
582 579
 	return nil
583 580
 }
584 581
 
585 582
 // EnsureJumpRule ensures the jump rule is on top
586 583
 func (iptable IPTable) EnsureJumpRule(fromChain, toChain string) error {
587 584
 	if iptable.Exists(Filter, fromChain, "-j", toChain) {
588
-		err := iptable.RawCombinedOutput("-D", fromChain, "-j", toChain)
589
-		if err != nil {
585
+		if err := iptable.RawCombinedOutput("-D", fromChain, "-j", toChain); err != nil {
590 586
 			return fmt.Errorf("unable to remove jump to %s rule in %s chain: %v", toChain, fromChain, err)
591 587
 		}
592 588
 	}
593
-
594
-	err := iptable.RawCombinedOutput("-I", fromChain, "-j", toChain)
595
-	if err != nil {
589
+	if err := iptable.RawCombinedOutput("-I", fromChain, "-j", toChain); err != nil {
596 590
 		return fmt.Errorf("unable to insert jump to %s rule in %s chain: %v", toChain, fromChain, err)
597 591
 	}
598
-
599 592
 	return nil
600 593
 }