Browse code

test: use `T.Setenv` to set env vars in tests

This commit replaces `os.Setenv` with `t.Setenv` in tests. The
environment variable is automatically restored to its original value
when the test and all its subtests complete.

Reference: https://pkg.go.dev/testing#T.Setenv
Signed-off-by: Eng Zer Jun <engzerjun@gmail.com>

Eng Zer Jun authored on 2022/04/23 18:01:58
Showing 10 changed files
... ...
@@ -6,7 +6,6 @@ import (
6 6
 	"io"
7 7
 	"net/http"
8 8
 	"net/url"
9
-	"os"
10 9
 	"runtime"
11 10
 	"strings"
12 11
 	"testing"
... ...
@@ -189,7 +188,7 @@ func TestNewClientWithOpsFromEnvSetsDefaultVersion(t *testing.T) {
189 189
 	assert.Check(t, is.Equal(client.ClientVersion(), api.DefaultVersion))
190 190
 
191 191
 	const expected = "1.22"
192
-	_ = os.Setenv("DOCKER_API_VERSION", expected)
192
+	t.Setenv("DOCKER_API_VERSION", expected)
193 193
 	client, err = NewClientWithOpts(FromEnv)
194 194
 	if err != nil {
195 195
 		t.Fatal(err)
... ...
@@ -1654,8 +1654,7 @@ func BenchmarkUnwrapEvents(b *testing.B) {
1654 1654
 
1655 1655
 func TestNewAWSLogsClientCredentialEndpointDetect(t *testing.T) {
1656 1656
 	// required for the cloudwatchlogs client
1657
-	os.Setenv("AWS_REGION", "us-west-2")
1658
-	defer os.Unsetenv("AWS_REGION")
1657
+	t.Setenv("AWS_REGION", "us-west-2")
1659 1658
 
1660 1659
 	credsResp := `{
1661 1660
 		"AccessKeyId" :    "test-access-key-id",
... ...
@@ -1694,17 +1693,13 @@ func TestNewAWSLogsClientCredentialEndpointDetect(t *testing.T) {
1694 1694
 
1695 1695
 func TestNewAWSLogsClientCredentialEnvironmentVariable(t *testing.T) {
1696 1696
 	// required for the cloudwatchlogs client
1697
-	os.Setenv("AWS_REGION", "us-west-2")
1698
-	defer os.Unsetenv("AWS_REGION")
1697
+	t.Setenv("AWS_REGION", "us-west-2")
1699 1698
 
1700 1699
 	expectedAccessKeyID := "test-access-key-id"
1701 1700
 	expectedSecretAccessKey := "test-secret-access-key"
1702 1701
 
1703
-	os.Setenv("AWS_ACCESS_KEY_ID", expectedAccessKeyID)
1704
-	defer os.Unsetenv("AWS_ACCESS_KEY_ID")
1705
-
1706
-	os.Setenv("AWS_SECRET_ACCESS_KEY", expectedSecretAccessKey)
1707
-	defer os.Unsetenv("AWS_SECRET_ACCESS_KEY")
1702
+	t.Setenv("AWS_ACCESS_KEY_ID", expectedAccessKeyID)
1703
+	t.Setenv("AWS_SECRET_ACCESS_KEY", expectedSecretAccessKey)
1708 1704
 
1709 1705
 	info := logger.Info{
1710 1706
 		Config: map[string]string{},
... ...
@@ -1724,8 +1719,7 @@ func TestNewAWSLogsClientCredentialEnvironmentVariable(t *testing.T) {
1724 1724
 
1725 1725
 func TestNewAWSLogsClientCredentialSharedFile(t *testing.T) {
1726 1726
 	// required for the cloudwatchlogs client
1727
-	os.Setenv("AWS_REGION", "us-west-2")
1728
-	defer os.Unsetenv("AWS_REGION")
1727
+	t.Setenv("AWS_REGION", "us-west-2")
1729 1728
 
1730 1729
 	expectedAccessKeyID := "test-access-key-id"
1731 1730
 	expectedSecretAccessKey := "test-secret-access-key"
... ...
@@ -1750,8 +1744,7 @@ func TestNewAWSLogsClientCredentialSharedFile(t *testing.T) {
1750 1750
 	os.Unsetenv("AWS_ACCESS_KEY_ID")
1751 1751
 	os.Unsetenv("AWS_SECRET_ACCESS_KEY")
1752 1752
 
1753
-	os.Setenv("AWS_SHARED_CREDENTIALS_FILE", tmpfile.Name())
1754
-	defer os.Unsetenv("AWS_SHARED_CREDENTIALS_FILE")
1753
+	t.Setenv("AWS_SHARED_CREDENTIALS_FILE", tmpfile.Name())
1755 1754
 
1756 1755
 	info := logger.Info{
1757 1756
 		Config: map[string]string{},
... ...
@@ -5,7 +5,6 @@ import (
5 5
 	"context"
6 6
 	"fmt"
7 7
 	"net/http"
8
-	"os"
9 8
 	"runtime"
10 9
 	"testing"
11 10
 	"time"
... ...
@@ -807,9 +806,7 @@ func TestRawFormatWithoutTag(t *testing.T) {
807 807
 // Verify that we will send messages in batches with default batching parameters,
808 808
 // but change frequency to be sure that numOfRequests will match expected 17 requests
809 809
 func TestBatching(t *testing.T) {
810
-	if err := os.Setenv(envVarPostMessagesFrequency, "10h"); err != nil {
811
-		t.Fatal(err)
812
-	}
810
+	t.Setenv(envVarPostMessagesFrequency, "10h")
813 811
 
814 812
 	hec := NewHTTPEventCollectorMock(t)
815 813
 
... ...
@@ -865,17 +862,11 @@ func TestBatching(t *testing.T) {
865 865
 	if err != nil {
866 866
 		t.Fatal(err)
867 867
 	}
868
-
869
-	if err := os.Setenv(envVarPostMessagesFrequency, ""); err != nil {
870
-		t.Fatal(err)
871
-	}
872 868
 }
873 869
 
874 870
 // Verify that test is using time to fire events not rare than specified frequency
875 871
 func TestFrequency(t *testing.T) {
876
-	if err := os.Setenv(envVarPostMessagesFrequency, "5ms"); err != nil {
877
-		t.Fatal(err)
878
-	}
872
+	t.Setenv(envVarPostMessagesFrequency, "5ms")
879 873
 
880 874
 	hec := NewHTTPEventCollectorMock(t)
881 875
 
... ...
@@ -938,30 +929,15 @@ func TestFrequency(t *testing.T) {
938 938
 	if err != nil {
939 939
 		t.Fatal(err)
940 940
 	}
941
-
942
-	if err := os.Setenv(envVarPostMessagesFrequency, ""); err != nil {
943
-		t.Fatal(err)
944
-	}
945 941
 }
946 942
 
947 943
 // Simulate behavior similar to first version of Splunk Logging Driver, when we were sending one message
948 944
 // per request
949 945
 func TestOneMessagePerRequest(t *testing.T) {
950
-	if err := os.Setenv(envVarPostMessagesFrequency, "10h"); err != nil {
951
-		t.Fatal(err)
952
-	}
953
-
954
-	if err := os.Setenv(envVarPostMessagesBatchSize, "1"); err != nil {
955
-		t.Fatal(err)
956
-	}
957
-
958
-	if err := os.Setenv(envVarBufferMaximum, "1"); err != nil {
959
-		t.Fatal(err)
960
-	}
961
-
962
-	if err := os.Setenv(envVarStreamChannelSize, "0"); err != nil {
963
-		t.Fatal(err)
964
-	}
946
+	t.Setenv(envVarPostMessagesFrequency, "10h")
947
+	t.Setenv(envVarPostMessagesBatchSize, "1")
948
+	t.Setenv(envVarBufferMaximum, "1")
949
+	t.Setenv(envVarStreamChannelSize, "0")
965 950
 
966 951
 	hec := NewHTTPEventCollectorMock(t)
967 952
 
... ...
@@ -1017,22 +993,6 @@ func TestOneMessagePerRequest(t *testing.T) {
1017 1017
 	if err != nil {
1018 1018
 		t.Fatal(err)
1019 1019
 	}
1020
-
1021
-	if err := os.Setenv(envVarPostMessagesFrequency, ""); err != nil {
1022
-		t.Fatal(err)
1023
-	}
1024
-
1025
-	if err := os.Setenv(envVarPostMessagesBatchSize, ""); err != nil {
1026
-		t.Fatal(err)
1027
-	}
1028
-
1029
-	if err := os.Setenv(envVarBufferMaximum, ""); err != nil {
1030
-		t.Fatal(err)
1031
-	}
1032
-
1033
-	if err := os.Setenv(envVarStreamChannelSize, ""); err != nil {
1034
-		t.Fatal(err)
1035
-	}
1036 1020
 }
1037 1021
 
1038 1022
 // Driver should not be created when HEC is unresponsive
... ...
@@ -1136,17 +1096,9 @@ func TestSkipVerify(t *testing.T) {
1136 1136
 
1137 1137
 // Verify logic for when we filled whole buffer
1138 1138
 func TestBufferMaximum(t *testing.T) {
1139
-	if err := os.Setenv(envVarPostMessagesBatchSize, "2"); err != nil {
1140
-		t.Fatal(err)
1141
-	}
1142
-
1143
-	if err := os.Setenv(envVarBufferMaximum, "10"); err != nil {
1144
-		t.Fatal(err)
1145
-	}
1146
-
1147
-	if err := os.Setenv(envVarStreamChannelSize, "0"); err != nil {
1148
-		t.Fatal(err)
1149
-	}
1139
+	t.Setenv(envVarPostMessagesBatchSize, "2")
1140
+	t.Setenv(envVarBufferMaximum, "10")
1141
+	t.Setenv(envVarStreamChannelSize, "0")
1150 1142
 
1151 1143
 	hec := NewHTTPEventCollectorMock(t)
1152 1144
 	hec.simulateErr(true)
... ...
@@ -1209,33 +1161,13 @@ func TestBufferMaximum(t *testing.T) {
1209 1209
 	if err != nil {
1210 1210
 		t.Fatal(err)
1211 1211
 	}
1212
-
1213
-	if err := os.Setenv(envVarPostMessagesBatchSize, ""); err != nil {
1214
-		t.Fatal(err)
1215
-	}
1216
-
1217
-	if err := os.Setenv(envVarBufferMaximum, ""); err != nil {
1218
-		t.Fatal(err)
1219
-	}
1220
-
1221
-	if err := os.Setenv(envVarStreamChannelSize, ""); err != nil {
1222
-		t.Fatal(err)
1223
-	}
1224 1212
 }
1225 1213
 
1226 1214
 // Verify that we are not blocking close when HEC is down for the whole time
1227 1215
 func TestServerAlwaysDown(t *testing.T) {
1228
-	if err := os.Setenv(envVarPostMessagesBatchSize, "2"); err != nil {
1229
-		t.Fatal(err)
1230
-	}
1231
-
1232
-	if err := os.Setenv(envVarBufferMaximum, "4"); err != nil {
1233
-		t.Fatal(err)
1234
-	}
1235
-
1236
-	if err := os.Setenv(envVarStreamChannelSize, "0"); err != nil {
1237
-		t.Fatal(err)
1238
-	}
1216
+	t.Setenv(envVarPostMessagesBatchSize, "2")
1217
+	t.Setenv(envVarBufferMaximum, "4")
1218
+	t.Setenv(envVarStreamChannelSize, "0")
1239 1219
 
1240 1220
 	hec := NewHTTPEventCollectorMock(t)
1241 1221
 	hec.simulateServerError = true
... ...
@@ -1281,18 +1213,6 @@ func TestServerAlwaysDown(t *testing.T) {
1281 1281
 	if err != nil {
1282 1282
 		t.Fatal(err)
1283 1283
 	}
1284
-
1285
-	if err := os.Setenv(envVarPostMessagesBatchSize, ""); err != nil {
1286
-		t.Fatal(err)
1287
-	}
1288
-
1289
-	if err := os.Setenv(envVarBufferMaximum, ""); err != nil {
1290
-		t.Fatal(err)
1291
-	}
1292
-
1293
-	if err := os.Setenv(envVarStreamChannelSize, ""); err != nil {
1294
-		t.Fatal(err)
1295
-	}
1296 1284
 }
1297 1285
 
1298 1286
 // Cannot send messages after we close driver
... ...
@@ -5009,16 +5009,14 @@ func (s *DockerRegistryAuthHtpasswdSuite) TestBuildFromAuthenticatedRegistry(c *
5009 5009
 }
5010 5010
 
5011 5011
 func (s *DockerRegistryAuthHtpasswdSuite) TestBuildWithExternalAuth(c *testing.T) {
5012
-	osPath := os.Getenv("PATH")
5013
-	defer os.Setenv("PATH", osPath)
5014
-
5015 5012
 	workingDir, err := os.Getwd()
5016 5013
 	assert.NilError(c, err)
5017 5014
 	absolute, err := filepath.Abs(filepath.Join(workingDir, "fixtures", "auth"))
5018 5015
 	assert.NilError(c, err)
5019
-	testPath := fmt.Sprintf("%s%c%s", osPath, filepath.ListSeparator, absolute)
5020 5016
 
5021
-	os.Setenv("PATH", testPath)
5017
+	osPath := os.Getenv("PATH")
5018
+	testPath := fmt.Sprintf("%s%c%s", osPath, filepath.ListSeparator, absolute)
5019
+	c.Setenv("PATH", testPath)
5022 5020
 
5023 5021
 	repoName := fmt.Sprintf("%v/dockercli/busybox:authtest", privateRegistryURL)
5024 5022
 
... ...
@@ -1693,12 +1693,7 @@ func (s *DockerDaemonSuite) TestDaemonStartWithDefaultTLSHost(c *testing.T) {
1693 1693
 		"--tlskey", "fixtures/https/server-key.pem")
1694 1694
 
1695 1695
 	// The client with --tlsverify should also use default host localhost:2376
1696
-	tmpHost := os.Getenv("DOCKER_HOST")
1697
-	defer func() {
1698
-		os.Setenv("DOCKER_HOST", tmpHost)
1699
-	}()
1700
-
1701
-	os.Setenv("DOCKER_HOST", "")
1696
+	c.Setenv("DOCKER_HOST", "")
1702 1697
 
1703 1698
 	out, _ := dockerCmd(
1704 1699
 		c,
... ...
@@ -15,16 +15,14 @@ import (
15 15
 func (s *DockerRegistryAuthHtpasswdSuite) TestLogoutWithExternalAuth(c *testing.T) {
16 16
 	s.d.StartWithBusybox(c)
17 17
 
18
-	osPath := os.Getenv("PATH")
19
-	defer os.Setenv("PATH", osPath)
20
-
21 18
 	workingDir, err := os.Getwd()
22 19
 	assert.NilError(c, err)
23 20
 	absolute, err := filepath.Abs(filepath.Join(workingDir, "fixtures", "auth"))
24 21
 	assert.NilError(c, err)
25
-	testPath := fmt.Sprintf("%s%c%s", osPath, filepath.ListSeparator, absolute)
26 22
 
27
-	os.Setenv("PATH", testPath)
23
+	osPath := os.Getenv("PATH")
24
+	testPath := fmt.Sprintf("%s%c%s", osPath, filepath.ListSeparator, absolute)
25
+	c.Setenv("PATH", testPath)
28 26
 
29 27
 	repoName := fmt.Sprintf("%v/dockercli/busybox:authtest", privateRegistryURL)
30 28
 
... ...
@@ -65,16 +63,14 @@ func (s *DockerRegistryAuthHtpasswdSuite) TestLogoutWithExternalAuth(c *testing.
65 65
 
66 66
 // #23100
67 67
 func (s *DockerRegistryAuthHtpasswdSuite) TestLogoutWithWrongHostnamesStored(c *testing.T) {
68
-	osPath := os.Getenv("PATH")
69
-	defer os.Setenv("PATH", osPath)
70
-
71 68
 	workingDir, err := os.Getwd()
72 69
 	assert.NilError(c, err)
73 70
 	absolute, err := filepath.Abs(filepath.Join(workingDir, "fixtures", "auth"))
74 71
 	assert.NilError(c, err)
75
-	testPath := fmt.Sprintf("%s%c%s", osPath, filepath.ListSeparator, absolute)
76 72
 
77
-	os.Setenv("PATH", testPath)
73
+	osPath := os.Getenv("PATH")
74
+	testPath := fmt.Sprintf("%s%c%s", osPath, filepath.ListSeparator, absolute)
75
+	c.Setenv("PATH", testPath)
78 76
 
79 77
 	cmd := exec.Command("docker-credential-shell-test", "store")
80 78
 	stdin := bytes.NewReader([]byte(fmt.Sprintf(`{"ServerURL": "https://%s", "Username": "%s", "Secret": "%s"}`, privateRegistryURL, s.reg.Username(), s.reg.Password())))
... ...
@@ -367,16 +367,14 @@ func (s *DockerRegistrySuite) TestPullManifestList(c *testing.T) {
367 367
 
368 368
 // #23100
369 369
 func (s *DockerRegistryAuthHtpasswdSuite) TestPullWithExternalAuthLoginWithScheme(c *testing.T) {
370
-	osPath := os.Getenv("PATH")
371
-	defer os.Setenv("PATH", osPath)
372
-
373 370
 	workingDir, err := os.Getwd()
374 371
 	assert.NilError(c, err)
375 372
 	absolute, err := filepath.Abs(filepath.Join(workingDir, "fixtures", "auth"))
376 373
 	assert.NilError(c, err)
377
-	testPath := fmt.Sprintf("%s%c%s", osPath, filepath.ListSeparator, absolute)
378 374
 
379
-	os.Setenv("PATH", testPath)
375
+	osPath := os.Getenv("PATH")
376
+	testPath := fmt.Sprintf("%s%c%s", osPath, filepath.ListSeparator, absolute)
377
+	c.Setenv("PATH", testPath)
380 378
 
381 379
 	repoName := fmt.Sprintf("%v/dockercli/busybox:authtest", privateRegistryURL)
382 380
 
... ...
@@ -411,16 +409,14 @@ func (s *DockerRegistryAuthHtpasswdSuite) TestPullWithExternalAuthLoginWithSchem
411 411
 }
412 412
 
413 413
 func (s *DockerRegistryAuthHtpasswdSuite) TestPullWithExternalAuth(c *testing.T) {
414
-	osPath := os.Getenv("PATH")
415
-	defer os.Setenv("PATH", osPath)
416
-
417 414
 	workingDir, err := os.Getwd()
418 415
 	assert.NilError(c, err)
419 416
 	absolute, err := filepath.Abs(filepath.Join(workingDir, "fixtures", "auth"))
420 417
 	assert.NilError(c, err)
421
-	testPath := fmt.Sprintf("%s%c%s", osPath, filepath.ListSeparator, absolute)
422 418
 
423
-	os.Setenv("PATH", testPath)
419
+	osPath := os.Getenv("PATH")
420
+	testPath := fmt.Sprintf("%s%c%s", osPath, filepath.ListSeparator, absolute)
421
+	c.Setenv("PATH", testPath)
424 422
 
425 423
 	repoName := fmt.Sprintf("%v/dockercli/busybox:authtest", privateRegistryURL)
426 424
 
... ...
@@ -243,8 +243,6 @@ func (s *DockerSuite) TestRunAttachDetachFromConfig(c *testing.T) {
243 243
 	keyA := []byte{97}
244 244
 
245 245
 	// Setup config
246
-	homeKey := homedir.Key()
247
-	homeVal := homedir.Get()
248 246
 	tmpDir, err := os.MkdirTemp("", "fake-home")
249 247
 	assert.NilError(c, err)
250 248
 	defer os.RemoveAll(tmpDir)
... ...
@@ -253,8 +251,7 @@ func (s *DockerSuite) TestRunAttachDetachFromConfig(c *testing.T) {
253 253
 	os.Mkdir(dotDocker, 0600)
254 254
 	tmpCfg := filepath.Join(dotDocker, "config.json")
255 255
 
256
-	defer func() { os.Setenv(homeKey, homeVal) }()
257
-	os.Setenv(homeKey, tmpDir)
256
+	c.Setenv(homedir.Key(), tmpDir)
258 257
 
259 258
 	data := `{
260 259
 		"detachKeys": "ctrl-a,a"
... ...
@@ -326,8 +323,6 @@ func (s *DockerSuite) TestRunAttachDetachKeysOverrideConfig(c *testing.T) {
326 326
 	keyA := []byte{97}
327 327
 
328 328
 	// Setup config
329
-	homeKey := homedir.Key()
330
-	homeVal := homedir.Get()
331 329
 	tmpDir, err := os.MkdirTemp("", "fake-home")
332 330
 	assert.NilError(c, err)
333 331
 	defer os.RemoveAll(tmpDir)
... ...
@@ -336,8 +331,7 @@ func (s *DockerSuite) TestRunAttachDetachKeysOverrideConfig(c *testing.T) {
336 336
 	os.Mkdir(dotDocker, 0600)
337 337
 	tmpCfg := filepath.Join(dotDocker, "config.json")
338 338
 
339
-	defer func() { os.Setenv(homeKey, homeVal) }()
340
-	os.Setenv(homeKey, tmpDir)
339
+	c.Setenv(homedir.Key(), tmpDir)
341 340
 
342 341
 	data := `{
343 342
 		"detachKeys": "ctrl-e,e"
... ...
@@ -1372,8 +1372,7 @@ func TestDisablePigz(t *testing.T) {
1372 1372
 		t.Log("Test will not check full path when Pigz not installed")
1373 1373
 	}
1374 1374
 
1375
-	os.Setenv("MOBY_DISABLE_PIGZ", "true")
1376
-	defer os.Unsetenv("MOBY_DISABLE_PIGZ")
1375
+	t.Setenv("MOBY_DISABLE_PIGZ", "true")
1377 1376
 
1378 1377
 	r := testDecompressStream(t, "gz", "gzip -f")
1379 1378
 	// For the bufio pool
... ...
@@ -3,7 +3,6 @@ package jsonmessage // import "github.com/docker/docker/pkg/jsonmessage"
3 3
 import (
4 4
 	"bytes"
5 5
 	"fmt"
6
-	"os"
7 6
 	"strings"
8 7
 	"testing"
9 8
 	"time"
... ...
@@ -268,8 +267,7 @@ func TestDisplayJSONMessagesStream(t *testing.T) {
268 268
 
269 269
 	// Use $TERM which is unlikely to exist, forcing DisplayJSONMessageStream to
270 270
 	// (hopefully) use &noTermInfo.
271
-	origTerm := os.Getenv("TERM")
272
-	os.Setenv("TERM", "xyzzy-non-existent-terminfo")
271
+	t.Setenv("TERM", "xyzzy-non-existent-terminfo")
273 272
 
274 273
 	for jsonMessage, expectedMessages := range messages {
275 274
 		data := bytes.NewBuffer([]byte{})
... ...
@@ -294,6 +292,4 @@ func TestDisplayJSONMessagesStream(t *testing.T) {
294 294
 			t.Fatalf("\nExpected %q\n     got %q", expectedMessages[1], data.String())
295 295
 		}
296 296
 	}
297
-	os.Setenv("TERM", origTerm)
298
-
299 297
 }