Browse code

Fix linux mount calls not applying propagation type changes

Propagation type changes must be done as a separate call, in the
same way as read only bind mounts.

To fix this:
1. Ensure propagation type change flags aren't included in other calls.
2. Apply propagation type change in a separate call.

Also:
* Make it clear which parameters are ignored by passing them as empty.
* Add tests to ensure Mount options are applied correctly.

Fixes #30415

Signed-off-by: Steven Hartland <steven.hartland@multiplay.co.uk>

Steven Hartland authored on 2017/01/25 02:44:00
Showing 2 changed files
... ...
@@ -4,15 +4,50 @@ import (
4 4
 	"syscall"
5 5
 )
6 6
 
7
-func mount(device, target, mType string, flag uintptr, data string) error {
8
-	if err := syscall.Mount(device, target, mType, flag, data); err != nil {
9
-		return err
7
+const (
8
+	// ptypes is the set propagation types.
9
+	ptypes = syscall.MS_SHARED | syscall.MS_PRIVATE | syscall.MS_SLAVE | syscall.MS_UNBINDABLE
10
+
11
+	// pflags is the full set valid flags for a change propagation call.
12
+	pflags = ptypes | syscall.MS_REC | syscall.MS_SILENT
13
+
14
+	// broflags is the combination of bind and read only
15
+	broflags = syscall.MS_BIND | syscall.MS_RDONLY
16
+)
17
+
18
+// isremount returns true if either device name or flags identify a remount request, false otherwise.
19
+func isremount(device string, flags uintptr) bool {
20
+	switch {
21
+	// We treat device "" and "none" as a remount request to provide compatibility with
22
+	// requests that don't explicitly set MS_REMOUNT such as those manipulating bind mounts.
23
+	case flags&syscall.MS_REMOUNT != 0, device == "", device == "none":
24
+		return true
25
+	default:
26
+		return false
27
+	}
28
+}
29
+
30
+func mount(device, target, mType string, flags uintptr, data string) error {
31
+	oflags := flags &^ ptypes
32
+	if !isremount(device, flags) {
33
+		// Initial call applying all non-propagation flags.
34
+		if err := syscall.Mount(device, target, mType, oflags, data); err != nil {
35
+			return err
36
+		}
10 37
 	}
11 38
 
12
-	// If we have a bind mount or remount, remount...
13
-	if flag&syscall.MS_BIND == syscall.MS_BIND && flag&syscall.MS_RDONLY == syscall.MS_RDONLY {
14
-		return syscall.Mount(device, target, mType, flag|syscall.MS_REMOUNT, data)
39
+	if flags&ptypes != 0 {
40
+		// Change the propagation type.
41
+		if err := syscall.Mount("", target, "", flags&pflags, ""); err != nil {
42
+			return err
43
+		}
15 44
 	}
45
+
46
+	if oflags&broflags == broflags {
47
+		// Remount the bind to apply read only.
48
+		return syscall.Mount("", target, "", oflags|syscall.MS_REMOUNT, "")
49
+	}
50
+
16 51
 	return nil
17 52
 }
18 53
 
19 54
new file mode 100644
... ...
@@ -0,0 +1,195 @@
0
+// +build linux
1
+
2
+package mount
3
+
4
+import (
5
+	"fmt"
6
+	"io/ioutil"
7
+	"os"
8
+	"strings"
9
+	"testing"
10
+)
11
+
12
+func TestMount(t *testing.T) {
13
+	if os.Getuid() != 0 {
14
+		t.Skip("not root tests would fail")
15
+	}
16
+
17
+	source, err := ioutil.TempDir("", "mount-test-source-")
18
+	if err != nil {
19
+		t.Fatal(err)
20
+	}
21
+	defer os.RemoveAll(source)
22
+
23
+	// Ensure we have a known start point by mounting tmpfs with given options
24
+	if err := Mount("tmpfs", source, "tmpfs", "private"); err != nil {
25
+		t.Fatal(err)
26
+	}
27
+	defer ensureUnmount(t, source)
28
+	validateMount(t, source, "", "")
29
+	if t.Failed() {
30
+		t.FailNow()
31
+	}
32
+
33
+	target, err := ioutil.TempDir("", "mount-test-target-")
34
+	if err != nil {
35
+		t.Fatal(err)
36
+	}
37
+	defer os.RemoveAll(target)
38
+
39
+	tests := []struct {
40
+		source           string
41
+		ftype            string
42
+		options          string
43
+		expectedOpts     string
44
+		expectedOptional string
45
+	}{
46
+		// No options
47
+		{"tmpfs", "tmpfs", "", "", ""},
48
+		// Default rw / ro test
49
+		{source, "", "bind", "", ""},
50
+		{source, "", "bind,private", "", ""},
51
+		{source, "", "bind,shared", "", "shared"},
52
+		{source, "", "bind,slave", "", "master"},
53
+		{source, "", "bind,unbindable", "", "unbindable"},
54
+		// Read Write tests
55
+		{source, "", "bind,rw", "rw", ""},
56
+		{source, "", "bind,rw,private", "rw", ""},
57
+		{source, "", "bind,rw,shared", "rw", "shared"},
58
+		{source, "", "bind,rw,slave", "rw", "master"},
59
+		{source, "", "bind,rw,unbindable", "rw", "unbindable"},
60
+		// Read Only tests
61
+		{source, "", "bind,ro", "ro", ""},
62
+		{source, "", "bind,ro,private", "ro", ""},
63
+		{source, "", "bind,ro,shared", "ro", "shared"},
64
+		{source, "", "bind,ro,slave", "ro", "master"},
65
+		{source, "", "bind,ro,unbindable", "ro", "unbindable"},
66
+	}
67
+
68
+	for _, tc := range tests {
69
+		ftype, options := tc.ftype, tc.options
70
+		if tc.ftype == "" {
71
+			ftype = "none"
72
+		}
73
+		if tc.options == "" {
74
+			options = "none"
75
+		}
76
+
77
+		t.Run(fmt.Sprintf("%v-%v", ftype, options), func(t *testing.T) {
78
+			if strings.Contains(tc.options, "slave") {
79
+				// Slave requires a shared source
80
+				if err := MakeShared(source); err != nil {
81
+					t.Fatal(err)
82
+				}
83
+				defer func() {
84
+					if err := MakePrivate(source); err != nil {
85
+						t.Fatal(err)
86
+					}
87
+				}()
88
+			}
89
+			if err := Mount(tc.source, target, tc.ftype, tc.options); err != nil {
90
+				t.Fatal(err)
91
+			}
92
+			defer ensureUnmount(t, target)
93
+			validateMount(t, target, tc.expectedOpts, tc.expectedOptional)
94
+		})
95
+	}
96
+}
97
+
98
+// ensureUnmount umounts mnt checking for errors
99
+func ensureUnmount(t *testing.T, mnt string) {
100
+	if err := Unmount(mnt); err != nil {
101
+		t.Error(err)
102
+	}
103
+}
104
+
105
+// validateMount checks that mnt has the given options
106
+func validateMount(t *testing.T, mnt string, opts, optional string) {
107
+	info, err := GetMounts()
108
+	if err != nil {
109
+		t.Fatal(err)
110
+	}
111
+
112
+	wantedOpts := make(map[string]struct{})
113
+	if opts != "" {
114
+		for _, opt := range strings.Split(opts, ",") {
115
+			wantedOpts[opt] = struct{}{}
116
+		}
117
+	}
118
+
119
+	wantedOptional := make(map[string]struct{})
120
+	if optional != "" {
121
+		for _, opt := range strings.Split(optional, ",") {
122
+			wantedOptional[opt] = struct{}{}
123
+		}
124
+	}
125
+
126
+	mnts := make(map[int]*Info, len(info))
127
+	for _, mi := range info {
128
+		mnts[mi.ID] = mi
129
+	}
130
+
131
+	for _, mi := range info {
132
+		if mi.Mountpoint != mnt {
133
+			continue
134
+		}
135
+
136
+		// Use parent info as the defaults
137
+		p := mnts[mi.Parent]
138
+		pOpts := make(map[string]struct{})
139
+		if p.Opts != "" {
140
+			for _, opt := range strings.Split(p.Opts, ",") {
141
+				pOpts[clean(opt)] = struct{}{}
142
+			}
143
+		}
144
+		pOptional := make(map[string]struct{})
145
+		if p.Optional != "" {
146
+			for _, field := range strings.Split(p.Optional, ",") {
147
+				pOptional[clean(field)] = struct{}{}
148
+			}
149
+		}
150
+
151
+		// Validate Opts
152
+		if mi.Opts != "" {
153
+			for _, opt := range strings.Split(mi.Opts, ",") {
154
+				opt = clean(opt)
155
+				if !has(wantedOpts, opt) && !has(pOpts, opt) {
156
+					t.Errorf("unexpected mount option %q expected %q", opt, opts)
157
+				}
158
+				delete(wantedOpts, opt)
159
+			}
160
+		}
161
+		for opt := range wantedOpts {
162
+			t.Errorf("missing mount option %q found %q", opt, mi.Opts)
163
+		}
164
+
165
+		// Validate Optional
166
+		if mi.Optional != "" {
167
+			for _, field := range strings.Split(mi.Optional, ",") {
168
+				field = clean(field)
169
+				if !has(wantedOptional, field) && !has(pOptional, field) {
170
+					t.Errorf("unexpected optional failed %q expected %q", field, optional)
171
+				}
172
+				delete(wantedOptional, field)
173
+			}
174
+		}
175
+		for field := range wantedOptional {
176
+			t.Errorf("missing optional field %q found %q", field, mi.Optional)
177
+		}
178
+
179
+		return
180
+	}
181
+
182
+	t.Errorf("failed to find mount %q", mnt)
183
+}
184
+
185
+// clean strips off any value param after the colon
186
+func clean(v string) string {
187
+	return strings.SplitN(v, ":", 2)[0]
188
+}
189
+
190
+// has returns true if key is a member of m
191
+func has(m map[string]struct{}, key string) bool {
192
+	_, ok := m[key]
193
+	return ok
194
+}