Browse code

opts: simplify ValidateEnv to use os.LookupEnv

os.LookupEnv() was not available yet at the time that this was
implemented (9ab73260f8e4662e7321b257c636928892f023cf), but now
provides the functionality we need, so replacing our custom handling.

Signed-off-by: Sebastiaan van Stijn <github@gone.nl>

Sebastiaan van Stijn authored on 2020/08/14 23:54:08
Showing 2 changed files
... ...
@@ -1,48 +1,30 @@
1 1
 package opts // import "github.com/docker/docker/opts"
2 2
 
3 3
 import (
4
-	"fmt"
5 4
 	"os"
6
-	"runtime"
7 5
 	"strings"
8 6
 
9 7
 	"github.com/pkg/errors"
10 8
 )
11 9
 
12 10
 // ValidateEnv validates an environment variable and returns it.
13
-// If no value is specified, it returns the current value using os.Getenv.
11
+// If no value is specified, it obtains its value from the current environment
14 12
 //
15 13
 // As on ParseEnvFile and related to #16585, environment variable names
16
-// are not validate what so ever, it's up to application inside docker
14
+// are not validate whatsoever, it's up to application inside docker
17 15
 // to validate them or not.
18 16
 //
19 17
 // The only validation here is to check if name is empty, per #25099
20 18
 func ValidateEnv(val string) (string, error) {
21
-	arr := strings.Split(val, "=")
19
+	arr := strings.SplitN(val, "=", 2)
22 20
 	if arr[0] == "" {
23
-		return "", errors.Errorf("invalid environment variable: %s", val)
21
+		return "", errors.New("invalid environment variable: " + val)
24 22
 	}
25 23
 	if len(arr) > 1 {
26 24
 		return val, nil
27 25
 	}
28
-	if !doesEnvExist(val) {
29
-		return val, nil
30
-	}
31
-	return fmt.Sprintf("%s=%s", val, os.Getenv(val)), nil
32
-}
33
-
34
-func doesEnvExist(name string) bool {
35
-	for _, entry := range os.Environ() {
36
-		parts := strings.SplitN(entry, "=", 2)
37
-		if runtime.GOOS == "windows" {
38
-			// Environment variable are case-insensitive on Windows. PaTh, path and PATH are equivalent.
39
-			if strings.EqualFold(parts[0], name) {
40
-				return true
41
-			}
42
-		}
43
-		if parts[0] == name {
44
-			return true
45
-		}
26
+	if envVal, ok := os.LookupEnv(arr[0]); ok {
27
+		return arr[0] + "=" + envVal, nil
46 28
 	}
47
-	return false
29
+	return val, nil
48 30
 }
... ...
@@ -5,14 +5,17 @@ import (
5 5
 	"os"
6 6
 	"runtime"
7 7
 	"testing"
8
+
9
+	"gotest.tools/v3/assert"
8 10
 )
9 11
 
10 12
 func TestValidateEnv(t *testing.T) {
11
-	testcase := []struct {
13
+	type testCase struct {
12 14
 		value    string
13 15
 		expected string
14 16
 		err      error
15
-	}{
17
+	}
18
+	tests := []testCase{
16 19
 		{
17 20
 			value:    "a",
18 21
 			expected: "a",
... ...
@@ -51,7 +54,11 @@ func TestValidateEnv(t *testing.T) {
51 51
 		},
52 52
 		{
53 53
 			value: "=a",
54
-			err:   fmt.Errorf(fmt.Sprintf("invalid environment variable: %s", "=a")),
54
+			err:   fmt.Errorf("invalid environment variable: =a"),
55
+		},
56
+		{
57
+			value:    "PATH=",
58
+			expected: "PATH=",
55 59
 		},
56 60
 		{
57 61
 			value:    "PATH=something",
... ...
@@ -83,42 +90,30 @@ func TestValidateEnv(t *testing.T) {
83 83
 		},
84 84
 		{
85 85
 			value: "=",
86
-			err:   fmt.Errorf(fmt.Sprintf("invalid environment variable: %s", "=")),
86
+			err:   fmt.Errorf("invalid environment variable: ="),
87 87
 		},
88 88
 	}
89 89
 
90
-	// Environment variables are case in-sensitive on Windows
91 90
 	if runtime.GOOS == "windows" {
92
-		tmp := struct {
93
-			value    string
94
-			expected string
95
-			err      error
96
-		}{
91
+		// Environment variables are case in-sensitive on Windows
92
+		tests = append(tests, testCase{
97 93
 			value:    "PaTh",
98 94
 			expected: fmt.Sprintf("PaTh=%v", os.Getenv("PATH")),
99 95
 			err:      nil,
100
-		}
101
-		testcase = append(testcase, tmp)
96
+		})
102 97
 	}
103 98
 
104
-	for _, r := range testcase {
105
-		actual, err := ValidateEnv(r.value)
99
+	for _, tc := range tests {
100
+		tc := tc
101
+		t.Run(tc.value, func(t *testing.T) {
102
+			actual, err := ValidateEnv(tc.value)
106 103
 
107
-		if err != nil {
108
-			if r.err == nil {
109
-				t.Fatalf("Expected err is nil, got err[%v]", err)
104
+			if tc.err == nil {
105
+				assert.NilError(t, err)
106
+			} else {
107
+				assert.Error(t, err, tc.err.Error())
110 108
 			}
111
-			if err.Error() != r.err.Error() {
112
-				t.Fatalf("Expected err[%v], got err[%v]", r.err, err)
113
-			}
114
-		}
115
-
116
-		if err == nil && r.err != nil {
117
-			t.Fatalf("Expected err[%v], but err is nil", r.err)
118
-		}
119
-
120
-		if actual != r.expected {
121
-			t.Fatalf("Expected [%v], got [%v]", r.expected, actual)
122
-		}
109
+			assert.Equal(t, actual, tc.expected)
110
+		})
123 111
 	}
124 112
 }