Browse code

libnetwork_test: overhaul TestParallel

TestParallel was written in an unusual style: it relies on the testing
package's parallel-test feature (t.Parallel) and lots of global state to
test one thing using three cooperating parallel tests. It is complicated
to reason about and quite brittle. For example, the command

go test -run TestParallel1 ./libnetwork

would deadlock, waiting for TestParallel2 and TestParallel3 to run until
the test timeout expired. The test would also be skipped if the
'-test.parallel' flag was set to less than three, either explicitly or
implicitly (default: GOMAXPROCS).

Overhaul TestParallel to address the aforementioned deficiencies and
get rid of mutable global state.

Signed-off-by: Cory Snider <csnider@mirantis.com>

Cory Snider authored on 2022/11/09 07:02:32
Showing 2 changed files
... ...
@@ -3,13 +3,10 @@ package libnetwork_test
3 3
 import (
4 4
 	"bytes"
5 5
 	"encoding/json"
6
-	"flag"
7 6
 	"fmt"
8 7
 	"net"
9 8
 	"os"
10 9
 	"os/exec"
11
-	"runtime"
12
-	"strconv"
13 10
 	"strings"
14 11
 	"sync"
15 12
 	"testing"
... ...
@@ -22,23 +19,17 @@ import (
22 22
 	"github.com/docker/docker/libnetwork/testutils"
23 23
 	"github.com/docker/docker/libnetwork/types"
24 24
 	"github.com/docker/docker/pkg/reexec"
25
+	"github.com/pkg/errors"
25 26
 	"github.com/sirupsen/logrus"
26 27
 	"github.com/vishvananda/netlink"
27 28
 	"github.com/vishvananda/netns"
29
+	"golang.org/x/sync/errgroup"
28 30
 )
29 31
 
30 32
 const (
31 33
 	bridgeNetType = "bridge"
32 34
 )
33 35
 
34
-// Shared state for createGlobalInstance() and runParallelTests().
35
-var (
36
-	origins = netns.None()
37
-	testns  = netns.None()
38
-
39
-	controller libnetwork.NetworkController
40
-)
41
-
42 36
 func makeTesthostNetwork(t *testing.T, c libnetwork.NetworkController) libnetwork.Network {
43 37
 	t.Helper()
44 38
 	n, err := createTestNetwork(c, "host", "testhost", options.Generic{}, nil, nil)
... ...
@@ -48,60 +39,6 @@ func makeTesthostNetwork(t *testing.T, c libnetwork.NetworkController) libnetwor
48 48
 	return n
49 49
 }
50 50
 
51
-func createGlobalInstance(t *testing.T) {
52
-	var err error
53
-	defer close(start)
54
-
55
-	origins, err = netns.Get()
56
-	if err != nil {
57
-		t.Fatal(err)
58
-	}
59
-
60
-	testns, err = netns.New()
61
-	if err != nil {
62
-		t.Fatal(err)
63
-	}
64
-
65
-	controller = newController(t)
66
-	t.Cleanup(controller.Stop)
67
-
68
-	netOption := options.Generic{
69
-		netlabel.GenericData: options.Generic{
70
-			"BridgeName": "network",
71
-		},
72
-	}
73
-
74
-	net1 := makeTesthostNetwork(t, controller)
75
-	net2, err := createTestNetwork(controller, "bridge", "network2", netOption, nil, nil)
76
-	if err != nil {
77
-		t.Fatal(err)
78
-	}
79
-
80
-	_, err = net1.CreateEndpoint("pep1")
81
-	if err != nil {
82
-		t.Fatal(err)
83
-	}
84
-
85
-	_, err = net2.CreateEndpoint("pep2")
86
-	if err != nil {
87
-		t.Fatal(err)
88
-	}
89
-
90
-	_, err = net2.CreateEndpoint("pep3")
91
-	if err != nil {
92
-		t.Fatal(err)
93
-	}
94
-
95
-	if sboxes[first-1], err = controller.NewSandbox(fmt.Sprintf("%drace", first), libnetwork.OptionUseDefaultSandbox()); err != nil {
96
-		t.Fatal(err)
97
-	}
98
-	for thd := first + 1; thd <= last; thd++ {
99
-		if sboxes[thd-1], err = controller.NewSandbox(fmt.Sprintf("%drace", thd)); err != nil {
100
-			t.Fatal(err)
101
-		}
102
-	}
103
-}
104
-
105 51
 func TestHost(t *testing.T) {
106 52
 	defer testutils.SetupTestOSContext(t)()
107 53
 	controller := newController(t)
... ...
@@ -906,160 +843,133 @@ func TestResolvConf(t *testing.T) {
906 906
 	}
907 907
 }
908 908
 
909
-func parallelJoin(t *testing.T, rc libnetwork.Sandbox, ep libnetwork.Endpoint, thrNumber int) {
910
-	debugf("J%d.", thrNumber)
911
-	var err error
912
-
913
-	sb := sboxes[thrNumber-1]
914
-	err = ep.Join(sb)
915
-
916
-	if err != nil {
917
-		if _, ok := err.(types.ForbiddenError); !ok {
918
-			t.Fatalf("thread %d: %v", thrNumber, err)
919
-		}
920
-		debugf("JE%d(%v).", thrNumber, err)
921
-	}
922
-	debugf("JD%d.", thrNumber)
923
-}
924
-
925
-func parallelLeave(t *testing.T, rc libnetwork.Sandbox, ep libnetwork.Endpoint, thrNumber int) {
926
-	debugf("L%d.", thrNumber)
927
-	var err error
928
-
929
-	sb := sboxes[thrNumber-1]
930
-
931
-	err = ep.Leave(sb)
932
-	if err != nil {
933
-		if _, ok := err.(types.ForbiddenError); !ok {
934
-			t.Fatalf("thread %d: %v", thrNumber, err)
935
-		}
936
-		debugf("LE%d(%v).", thrNumber, err)
937
-	}
938
-	debugf("LD%d.", thrNumber)
909
+type parallelTester struct {
910
+	osctx      *testutils.OSContext
911
+	controller libnetwork.NetworkController
912
+	net1, net2 libnetwork.Network
913
+	iterCnt    int
939 914
 }
940 915
 
941
-func runParallelTests(t *testing.T, thrNumber int) {
916
+func (pt parallelTester) Do(t *testing.T, thrNumber int) error {
942 917
 	var (
943 918
 		ep  libnetwork.Endpoint
944 919
 		sb  libnetwork.Sandbox
945 920
 		err error
946 921
 	)
947 922
 
948
-	t.Parallel()
923
+	teardown, err := pt.osctx.Set()
924
+	if err != nil {
925
+		return err
926
+	}
927
+	defer teardown(t)
949 928
 
950
-	pTest := flag.Lookup("test.parallel")
951
-	if pTest == nil {
952
-		t.Skip("Skipped because test.parallel flag not set;")
929
+	epName := fmt.Sprintf("pep%d", thrNumber)
930
+
931
+	if thrNumber == 1 {
932
+		ep, err = pt.net1.EndpointByName(epName)
933
+	} else {
934
+		ep, err = pt.net2.EndpointByName(epName)
953 935
 	}
954
-	numParallel, err := strconv.Atoi(pTest.Value.String())
936
+
955 937
 	if err != nil {
956
-		t.Fatal(err)
938
+		return errors.WithStack(err)
957 939
 	}
958
-	if numParallel < numThreads {
959
-		t.Skip("Skipped because t.parallel was less than ", numThreads)
940
+	if ep == nil {
941
+		return errors.New("got nil ep with no error")
960 942
 	}
961 943
 
962
-	runtime.LockOSThread()
963
-	if thrNumber == first {
964
-		createGlobalInstance(t)
965
-	} else {
966
-		<-start
967
-
968
-		thrdone := make(chan struct{})
969
-		done <- thrdone
970
-		defer close(thrdone)
944
+	cid := fmt.Sprintf("%drace", thrNumber)
945
+	pt.controller.WalkSandboxes(libnetwork.SandboxContainerWalker(&sb, cid))
946
+	if sb == nil {
947
+		return errors.Errorf("got nil sandbox for container: %s", cid)
948
+	}
971 949
 
972
-		if thrNumber == last {
973
-			defer close(done)
950
+	for i := 0; i < pt.iterCnt; i++ {
951
+		if err := ep.Join(sb); err != nil {
952
+			if _, ok := err.(types.ForbiddenError); !ok {
953
+				return errors.Wrapf(err, "thread %d", thrNumber)
954
+			}
974 955
 		}
975
-
976
-		err = netns.Set(testns)
977
-		if err != nil {
978
-			runtime.UnlockOSThread()
979
-			t.Fatal(err)
956
+		if err := ep.Leave(sb); err != nil {
957
+			if _, ok := err.(types.ForbiddenError); !ok {
958
+				return errors.Wrapf(err, "thread %d", thrNumber)
959
+			}
980 960
 		}
981 961
 	}
982
-	defer func() {
983
-		if err := netns.Set(origins); err != nil {
984
-			t.Fatalf("Error restoring the current thread's netns: %v", err)
985
-		} else {
986
-			runtime.UnlockOSThread()
987
-		}
988
-	}()
989 962
 
990
-	net1, err := controller.NetworkByName("testhost")
991
-	if err != nil {
992
-		t.Fatal(err)
993
-	}
994
-	if net1 == nil {
995
-		t.Fatal("Could not find testhost")
963
+	if err := errors.WithStack(sb.Delete()); err != nil {
964
+		return err
996 965
 	}
966
+	return errors.WithStack(ep.Delete(false))
967
+}
997 968
 
998
-	net2, err := controller.NetworkByName("network2")
999
-	if err != nil {
1000
-		t.Fatal(err)
1001
-	}
1002
-	if net2 == nil {
1003
-		t.Fatal("Could not find network2")
1004
-	}
969
+func TestParallel(t *testing.T) {
970
+	const (
971
+		first      = 1
972
+		last       = 3
973
+		numThreads = last - first + 1
974
+		iterCnt    = 25
975
+	)
1005 976
 
1006
-	epName := fmt.Sprintf("pep%d", thrNumber)
977
+	osctx := testutils.SetupTestOSContextEx(t)
978
+	defer osctx.Cleanup(t)
979
+	controller := newController(t)
1007 980
 
1008
-	if thrNumber == first {
1009
-		ep, err = net1.EndpointByName(epName)
1010
-	} else {
1011
-		ep, err = net2.EndpointByName(epName)
981
+	netOption := options.Generic{
982
+		netlabel.GenericData: options.Generic{
983
+			"BridgeName": "network",
984
+		},
1012 985
 	}
1013 986
 
987
+	net1 := makeTesthostNetwork(t, controller)
988
+	defer net1.Delete()
989
+	net2, err := createTestNetwork(controller, "bridge", "network2", netOption, nil, nil)
1014 990
 	if err != nil {
1015 991
 		t.Fatal(err)
1016 992
 	}
1017
-	if ep == nil {
1018
-		t.Fatal("Got nil ep with no error")
1019
-	}
993
+	defer net2.Delete()
1020 994
 
1021
-	cid := fmt.Sprintf("%drace", thrNumber)
1022
-	controller.WalkSandboxes(libnetwork.SandboxContainerWalker(&sb, cid))
1023
-	if sb == nil {
1024
-		t.Fatalf("Got nil sandbox for container: %s", cid)
995
+	_, err = net1.CreateEndpoint("pep1")
996
+	if err != nil {
997
+		t.Fatal(err)
1025 998
 	}
1026 999
 
1027
-	for i := 0; i < iterCnt; i++ {
1028
-		parallelJoin(t, sb, ep, thrNumber)
1029
-		parallelLeave(t, sb, ep, thrNumber)
1000
+	_, err = net2.CreateEndpoint("pep2")
1001
+	if err != nil {
1002
+		t.Fatal(err)
1030 1003
 	}
1031 1004
 
1032
-	debugf("\n")
1033
-
1034
-	err = sb.Delete()
1005
+	_, err = net2.CreateEndpoint("pep3")
1035 1006
 	if err != nil {
1036 1007
 		t.Fatal(err)
1037 1008
 	}
1038
-	if thrNumber == first {
1039
-		for thrdone := range done {
1040
-			<-thrdone
1041
-		}
1042 1009
 
1043
-		if testns != origins {
1044
-			testns.Close()
1045
-		}
1046
-		if err := net2.Delete(); err != nil {
1047
-			t.Fatal(err)
1048
-		}
1049
-	} else {
1050
-		err = ep.Delete(false)
1051
-		if err != nil {
1010
+	sboxes := make([]libnetwork.Sandbox, numThreads)
1011
+	if sboxes[first-1], err = controller.NewSandbox(fmt.Sprintf("%drace", first), libnetwork.OptionUseDefaultSandbox()); err != nil {
1012
+		t.Fatal(err)
1013
+	}
1014
+	for thd := first + 1; thd <= last; thd++ {
1015
+		if sboxes[thd-1], err = controller.NewSandbox(fmt.Sprintf("%drace", thd)); err != nil {
1052 1016
 			t.Fatal(err)
1053 1017
 		}
1054 1018
 	}
1055
-}
1056 1019
 
1057
-func TestParallel1(t *testing.T) {
1058
-	runParallelTests(t, 1)
1059
-}
1020
+	pt := parallelTester{
1021
+		osctx:      osctx,
1022
+		controller: controller,
1023
+		net1:       net1,
1024
+		net2:       net2,
1025
+		iterCnt:    iterCnt,
1026
+	}
1060 1027
 
1061
-func TestParallel2(t *testing.T) {
1062
-	runParallelTests(t, 2)
1028
+	var eg errgroup.Group
1029
+	for i := first; i <= last; i++ {
1030
+		i := i
1031
+		eg.Go(func() error { return pt.Do(t, i) })
1032
+	}
1033
+	if err := eg.Wait(); err != nil {
1034
+		t.Fatalf("%+v", err)
1035
+	}
1063 1036
 }
1064 1037
 
1065 1038
 func TestBridge(t *testing.T) {
... ...
@@ -1150,10 +1060,6 @@ func isV6Listenable() bool {
1150 1150
 	return v6ListenableCached
1151 1151
 }
1152 1152
 
1153
-func TestParallel3(t *testing.T) {
1154
-	runParallelTests(t, 3)
1155
-}
1156
-
1157 1153
 func TestNullIpam(t *testing.T) {
1158 1154
 	defer testutils.SetupTestOSContext(t)()
1159 1155
 	controller := newController(t)
... ...
@@ -1332,23 +1332,3 @@ func TestValidRemoteDriver(t *testing.T) {
1332 1332
 		}
1333 1333
 	}()
1334 1334
 }
1335
-
1336
-var (
1337
-	start  = make(chan struct{})
1338
-	done   = make(chan chan struct{}, numThreads-1)
1339
-	sboxes = make([]libnetwork.Sandbox, numThreads)
1340
-)
1341
-
1342
-const (
1343
-	iterCnt    = 25
1344
-	numThreads = 3
1345
-	first      = 1
1346
-	last       = numThreads
1347
-	debug      = false
1348
-)
1349
-
1350
-func debugf(format string, a ...interface{}) {
1351
-	if debug {
1352
-		fmt.Printf(format, a...)
1353
-	}
1354
-}