Browse code

pkg/mount: implement/use filter for mountinfo parsing

Functions `GetMounts()` and `parseMountTable()` return all the entries
as read and parsed from /proc/self/mountinfo. In many cases the caller
is only interested only one or a few entries, not all of them.

One good example is `Mounted()` function, which looks for a specific
entry only. Another example is `RecursiveUnmount()` which is only
interested in mount under a specific path.

This commit adds `filter` argument to `GetMounts()` to implement
two things:
1. filter out entries a caller is not interested in
2. stop processing if a caller is found what it wanted

`nil` can be passed to get a backward-compatible behavior, i.e. return
all the entries.

A few filters are implemented:
- `PrefixFilter`: filters out all entries not under `prefix`
- `SingleEntryFilter`: looks for a specific entry

Finally, `Mounted()` is modified to use `SingleEntryFilter()`, and
`RecursiveUnmount()` is using `PrefixFilter()`.

Unit tests are added to check filters are working.

[v2: ditch NoFilter, use nil]
[v3: ditch GetMountsFiltered()]
[v4: add unit test for filters]
[v5: switch to gotestyourself]

Signed-off-by: Kir Kolyshkin <kolyshkin@gmail.com>

Kir Kolyshkin authored on 2018/01/20 04:56:32
Showing 14 changed files
... ...
@@ -74,7 +74,7 @@ func (daemon *Daemon) cleanupMounts() error {
74 74
 		return err
75 75
 	}
76 76
 
77
-	infos, err := mount.GetMounts()
77
+	infos, err := mount.GetMounts(nil)
78 78
 	if err != nil {
79 79
 		return errors.Wrap(err, "error reading mount table for cleanup")
80 80
 	}
... ...
@@ -145,7 +145,7 @@ func lookupZfsDataset(rootdir string) (string, error) {
145 145
 	}
146 146
 	wantedDev := stat.Dev
147 147
 
148
-	mounts, err := mount.GetMounts()
148
+	mounts, err := mount.GetMounts(nil)
149 149
 	if err != nil {
150 150
 		return "", err
151 151
 	}
... ...
@@ -398,7 +398,7 @@ func getSourceMount(source string) (string, string, error) {
398 398
 		return "", "", err
399 399
 	}
400 400
 
401
-	mountinfos, err := mount.GetMounts()
401
+	mountinfos, err := mount.GetMounts(nil)
402 402
 	if err != nil {
403 403
 		return "", "", err
404 404
 	}
... ...
@@ -260,7 +260,7 @@ func (s *DockerDaemonSuite) TestPluginVolumeRemoveOnRestart(c *check.C) {
260 260
 }
261 261
 
262 262
 func existsMountpointWithPrefix(mountpointPrefix string) (bool, error) {
263
-	mounts, err := mount.GetMounts()
263
+	mounts, err := mount.GetMounts(nil)
264 264
 	if err != nil {
265 265
 		return false, err
266 266
 	}
... ...
@@ -9,26 +9,48 @@ import (
9 9
 	"github.com/sirupsen/logrus"
10 10
 )
11 11
 
12
-// GetMounts retrieves a list of mounts for the current running process.
13
-func GetMounts() ([]*Info, error) {
14
-	return parseMountTable()
12
+// FilterFunc is a type defining a callback function
13
+// to filter out unwanted entries. It takes a pointer
14
+// to an Info struct (not fully populated, currently
15
+// only Mountpoint is filled in), and returns two booleans:
16
+//  - skip: true if the entry should be skipped
17
+//  - stop: true if parsing should be stopped after the entry
18
+type FilterFunc func(*Info) (skip, stop bool)
19
+
20
+// PrefixFilter discards all entries whose mount points
21
+// do not start with a prefix specified
22
+func PrefixFilter(prefix string) FilterFunc {
23
+	return func(m *Info) (bool, bool) {
24
+		skip := !strings.HasPrefix(m.Mountpoint, prefix)
25
+		return skip, false
26
+	}
27
+}
28
+
29
+// SingleEntryFilter looks for a specific entry
30
+func SingleEntryFilter(mp string) FilterFunc {
31
+	return func(m *Info) (bool, bool) {
32
+		if m.Mountpoint == mp {
33
+			return false, true // don't skip, stop now
34
+		}
35
+		return true, false // skip, keep going
36
+	}
37
+}
38
+
39
+// GetMounts retrieves a list of mounts for the current running process,
40
+// with an optional filter applied (use nil for no filter).
41
+func GetMounts(f FilterFunc) ([]*Info, error) {
42
+	return parseMountTable(f)
15 43
 }
16 44
 
17 45
 // Mounted determines if a specified mountpoint has been mounted.
18 46
 // On Linux it looks at /proc/self/mountinfo.
19 47
 func Mounted(mountpoint string) (bool, error) {
20
-	entries, err := parseMountTable()
48
+	entries, err := GetMounts(SingleEntryFilter(mountpoint))
21 49
 	if err != nil {
22 50
 		return false, err
23 51
 	}
24 52
 
25
-	// Search the table for the mountpoint
26
-	for _, e := range entries {
27
-		if e.Mountpoint == mountpoint {
28
-			return true, nil
29
-		}
30
-	}
31
-	return false, nil
53
+	return len(entries) > 0, nil
32 54
 }
33 55
 
34 56
 // Mount will mount filesystem according to the specified configuration, on the
... ...
@@ -66,7 +88,7 @@ func Unmount(target string) error {
66 66
 // RecursiveUnmount unmounts the target and all mounts underneath, starting with
67 67
 // the deepsest mount first.
68 68
 func RecursiveUnmount(target string) error {
69
-	mounts, err := GetMounts()
69
+	mounts, err := parseMountTable(PrefixFilter(target))
70 70
 	if err != nil {
71 71
 		return err
72 72
 	}
... ...
@@ -77,9 +99,6 @@ func RecursiveUnmount(target string) error {
77 77
 	})
78 78
 
79 79
 	for i, m := range mounts {
80
-		if !strings.HasPrefix(m.Mountpoint, target) {
81
-			continue
82
-		}
83 80
 		logrus.Debugf("Trying to unmount %s", m.Mountpoint)
84 81
 		err = unmount(m.Mountpoint, mntDetach)
85 82
 		if err != nil {
... ...
@@ -129,7 +129,7 @@ func TestMountReadonly(t *testing.T) {
129 129
 }
130 130
 
131 131
 func TestGetMounts(t *testing.T) {
132
-	mounts, err := GetMounts()
132
+	mounts, err := GetMounts(nil)
133 133
 	if err != nil {
134 134
 		t.Fatal(err)
135 135
 	}
... ...
@@ -121,7 +121,7 @@ func ensureUnmount(t *testing.T, mnt string) {
121 121
 
122 122
 // validateMount checks that mnt has the given options
123 123
 func validateMount(t *testing.T, mnt string, opts, optional, vfs string) {
124
-	info, err := GetMounts()
124
+	info, err := GetMounts(nil)
125 125
 	if err != nil {
126 126
 		t.Fatal(err)
127 127
 	}
... ...
@@ -15,7 +15,7 @@ import (
15 15
 
16 16
 // Parse /proc/self/mountinfo because comparing Dev and ino does not work from
17 17
 // bind mounts.
18
-func parseMountTable() ([]*Info, error) {
18
+func parseMountTable(filter FilterFunc) ([]*Info, error) {
19 19
 	var rawEntries *C.struct_statfs
20 20
 
21 21
 	count := int(C.getmntinfo(&rawEntries, C.MNT_WAIT))
... ...
@@ -32,10 +32,24 @@ func parseMountTable() ([]*Info, error) {
32 32
 	var out []*Info
33 33
 	for _, entry := range entries {
34 34
 		var mountinfo Info
35
+		var skip, stop bool
35 36
 		mountinfo.Mountpoint = C.GoString(&entry.f_mntonname[0])
37
+
38
+		if filter != nil {
39
+			// filter out entries we're not interested in
40
+			skip, stop = filter(p)
41
+			if skip {
42
+				continue
43
+			}
44
+		}
45
+
36 46
 		mountinfo.Source = C.GoString(&entry.f_mntfromname[0])
37 47
 		mountinfo.Fstype = C.GoString(&entry.f_fstypename[0])
48
+
38 49
 		out = append(out, &mountinfo)
50
+		if stop {
51
+			break
52
+		}
39 53
 	}
40 54
 	return out, nil
41 55
 }
... ...
@@ -28,17 +28,17 @@ const (
28 28
 
29 29
 // Parse /proc/self/mountinfo because comparing Dev and ino does not work from
30 30
 // bind mounts
31
-func parseMountTable() ([]*Info, error) {
31
+func parseMountTable(filter FilterFunc) ([]*Info, error) {
32 32
 	f, err := os.Open("/proc/self/mountinfo")
33 33
 	if err != nil {
34 34
 		return nil, err
35 35
 	}
36 36
 	defer f.Close()
37 37
 
38
-	return parseInfoFile(f)
38
+	return parseInfoFile(f, filter)
39 39
 }
40 40
 
41
-func parseInfoFile(r io.Reader) ([]*Info, error) {
41
+func parseInfoFile(r io.Reader, filter FilterFunc) ([]*Info, error) {
42 42
 	var (
43 43
 		s   = bufio.NewScanner(r)
44 44
 		out = []*Info{}
... ...
@@ -53,6 +53,7 @@ func parseInfoFile(r io.Reader) ([]*Info, error) {
53 53
 			p              = &Info{}
54 54
 			text           = s.Text()
55 55
 			optionalFields string
56
+			skip, stop     bool
56 57
 		)
57 58
 
58 59
 		if _, err := fmt.Sscanf(text, mountinfoFormat,
... ...
@@ -60,6 +61,13 @@ func parseInfoFile(r io.Reader) ([]*Info, error) {
60 60
 			&p.Root, &p.Mountpoint, &p.Opts, &optionalFields); err != nil {
61 61
 			return nil, fmt.Errorf("Scanning '%s' failed: %s", text, err)
62 62
 		}
63
+		if filter != nil {
64
+			// filter out entries we're not interested in
65
+			skip, stop = filter(p)
66
+			if skip {
67
+				continue
68
+			}
69
+		}
63 70
 		// Safe as mountinfo encodes mountpoints with spaces as \040.
64 71
 		index := strings.Index(text, " - ")
65 72
 		postSeparatorFields := strings.Fields(text[index+3:])
... ...
@@ -75,6 +83,9 @@ func parseInfoFile(r io.Reader) ([]*Info, error) {
75 75
 		p.Source = postSeparatorFields[1]
76 76
 		p.VfsOpts = strings.Join(postSeparatorFields[2:], " ")
77 77
 		out = append(out, p)
78
+		if stop {
79
+			break
80
+		}
78 81
 	}
79 82
 	return out, nil
80 83
 }
... ...
@@ -89,5 +100,5 @@ func PidMountInfo(pid int) ([]*Info, error) {
89 89
 	}
90 90
 	defer f.Close()
91 91
 
92
-	return parseInfoFile(f)
92
+	return parseInfoFile(f, nil)
93 93
 }
... ...
@@ -5,6 +5,8 @@ package mount // import "github.com/docker/docker/pkg/mount"
5 5
 import (
6 6
 	"bytes"
7 7
 	"testing"
8
+
9
+	"github.com/gotestyourself/gotestyourself/assert"
8 10
 )
9 11
 
10 12
 const (
... ...
@@ -424,7 +426,7 @@ const (
424 424
 
425 425
 func TestParseFedoraMountinfo(t *testing.T) {
426 426
 	r := bytes.NewBuffer([]byte(fedoraMountinfo))
427
-	_, err := parseInfoFile(r)
427
+	_, err := parseInfoFile(r, nil)
428 428
 	if err != nil {
429 429
 		t.Fatal(err)
430 430
 	}
... ...
@@ -432,7 +434,7 @@ func TestParseFedoraMountinfo(t *testing.T) {
432 432
 
433 433
 func TestParseUbuntuMountinfo(t *testing.T) {
434 434
 	r := bytes.NewBuffer([]byte(ubuntuMountInfo))
435
-	_, err := parseInfoFile(r)
435
+	_, err := parseInfoFile(r, nil)
436 436
 	if err != nil {
437 437
 		t.Fatal(err)
438 438
 	}
... ...
@@ -440,7 +442,7 @@ func TestParseUbuntuMountinfo(t *testing.T) {
440 440
 
441 441
 func TestParseGentooMountinfo(t *testing.T) {
442 442
 	r := bytes.NewBuffer([]byte(gentooMountinfo))
443
-	_, err := parseInfoFile(r)
443
+	_, err := parseInfoFile(r, nil)
444 444
 	if err != nil {
445 445
 		t.Fatal(err)
446 446
 	}
... ...
@@ -448,7 +450,7 @@ func TestParseGentooMountinfo(t *testing.T) {
448 448
 
449 449
 func TestParseFedoraMountinfoFields(t *testing.T) {
450 450
 	r := bytes.NewBuffer([]byte(fedoraMountinfo))
451
-	infos, err := parseInfoFile(r)
451
+	infos, err := parseInfoFile(r, nil)
452 452
 	if err != nil {
453 453
 		t.Fatal(err)
454 454
 	}
... ...
@@ -474,3 +476,27 @@ func TestParseFedoraMountinfoFields(t *testing.T) {
474 474
 		t.Fatalf("expected %#v, got %#v", mi, infos[0])
475 475
 	}
476 476
 }
477
+
478
+func TestParseMountinfoFilters(t *testing.T) {
479
+	r := bytes.NewReader([]byte(fedoraMountinfo))
480
+
481
+	infos, err := parseInfoFile(r, SingleEntryFilter("/sys/fs/cgroup"))
482
+	assert.NilError(t, err)
483
+	assert.Equal(t, 1, len(infos))
484
+
485
+	r.Reset([]byte(fedoraMountinfo))
486
+	infos, err = parseInfoFile(r, SingleEntryFilter("nonexistent"))
487
+	assert.NilError(t, err)
488
+	assert.Equal(t, 0, len(infos))
489
+
490
+	r.Reset([]byte(fedoraMountinfo))
491
+	infos, err = parseInfoFile(r, PrefixFilter("/sys"))
492
+	assert.NilError(t, err)
493
+	// there are 18 entries starting with /sys in fedoraMountinfo
494
+	assert.Equal(t, 18, len(infos))
495
+
496
+	r.Reset([]byte(fedoraMountinfo))
497
+	infos, err = parseInfoFile(r, PrefixFilter("nonexistent"))
498
+	assert.NilError(t, err)
499
+	assert.Equal(t, 0, len(infos))
500
+}
... ...
@@ -7,6 +7,6 @@ import (
7 7
 	"runtime"
8 8
 )
9 9
 
10
-func parseMountTable() ([]*Info, error) {
10
+func parseMountTable(f FilterFunc) ([]*Info, error) {
11 11
 	return nil, fmt.Errorf("mount.parseMountTable is not implemented on %s/%s", runtime.GOOS, runtime.GOARCH)
12 12
 }
... ...
@@ -1,6 +1,6 @@
1 1
 package mount // import "github.com/docker/docker/pkg/mount"
2 2
 
3
-func parseMountTable() ([]*Info, error) {
3
+func parseMountTable(f FilterFunc) ([]*Info, error) {
4 4
 	// Do NOT return an error!
5 5
 	return nil, nil
6 6
 }
... ...
@@ -66,7 +66,7 @@ func New(scope string, rootIDs idtools.IDPair) (*Root, error) {
66 66
 		return nil, err
67 67
 	}
68 68
 
69
-	mountInfos, err := mount.GetMounts()
69
+	mountInfos, err := mount.GetMounts(nil)
70 70
 	if err != nil {
71 71
 		logrus.Debugf("error looking up mounts for local volume cleanup: %v", err)
72 72
 	}
... ...
@@ -215,7 +215,7 @@ func TestCreateWithOpts(t *testing.T) {
215 215
 		}
216 216
 	}()
217 217
 
218
-	mountInfos, err := mount.GetMounts()
218
+	mountInfos, err := mount.GetMounts(nil)
219 219
 	if err != nil {
220 220
 		t.Fatal(err)
221 221
 	}