Browse code

getSourceMount(): simplify

The flow of getSourceMount was:
1 get all entries from /proc/self/mountinfo
2 do a linear search for the `source` directory
3 if found, return its data
4 get the parent directory of `source`, goto 2

The repeated linear search through the whole mountinfo (which can have
thousands of records) is inefficient. Instead, let's just

1 collect all the relevant records (only those mount points
that can be a parent of `source`)
2 find the record with the longest mountpath, return its data

This was tested manually with something like

```go
func TestGetSourceMount(t *testing.T) {
mnt, flags, err := getSourceMount("/sys/devices/msr/")
assert.NoError(t, err)
t.Logf("mnt: %v, flags: %v", mnt, flags)
}
```

...but it relies on having a specific mount points on the system
being used for testing.

[v2: add unit tests for ParentsFilter]

Signed-off-by: Kir Kolyshkin <kolyshkin@gmail.com>

Kir Kolyshkin authored on 2018/01/26 13:13:46
Showing 3 changed files
... ...
@@ -380,15 +380,6 @@ func specMapping(s []idtools.IDMap) []specs.LinuxIDMapping {
380 380
 	return ids
381 381
 }
382 382
 
383
-func getMountInfo(mountinfo []*mount.Info, dir string) *mount.Info {
384
-	for _, m := range mountinfo {
385
-		if m.Mountpoint == dir {
386
-			return m
387
-		}
388
-	}
389
-	return nil
390
-}
391
-
392 383
 // Get the source mount point of directory passed in as argument. Also return
393 384
 // optional fields.
394 385
 func getSourceMount(source string) (string, string, error) {
... ...
@@ -398,29 +389,26 @@ func getSourceMount(source string) (string, string, error) {
398 398
 		return "", "", err
399 399
 	}
400 400
 
401
-	mountinfos, err := mount.GetMounts(nil)
401
+	mi, err := mount.GetMounts(mount.ParentsFilter(sourcePath))
402 402
 	if err != nil {
403 403
 		return "", "", err
404 404
 	}
405
-
406
-	mountinfo := getMountInfo(mountinfos, sourcePath)
407
-	if mountinfo != nil {
408
-		return sourcePath, mountinfo.Optional, nil
405
+	if len(mi) < 1 {
406
+		return "", "", fmt.Errorf("Can't find mount point of %s", source)
409 407
 	}
410 408
 
411
-	path := sourcePath
412
-	for {
413
-		path = filepath.Dir(path)
414
-
415
-		mountinfo = getMountInfo(mountinfos, path)
416
-		if mountinfo != nil {
417
-			return path, mountinfo.Optional, nil
418
-		}
419
-
420
-		if path == "/" {
421
-			break
409
+	// find the longest mount point
410
+	var idx, maxlen int
411
+	for i := range mi {
412
+		if len(mi[i].Mountpoint) > maxlen {
413
+			maxlen = len(mi[i].Mountpoint)
414
+			idx = i
422 415
 		}
423 416
 	}
417
+	// and return it unless it's "/"
418
+	if mi[idx].Mountpoint != "/" {
419
+		return mi[idx].Mountpoint, mi[idx].Optional, nil
420
+	}
424 421
 
425 422
 	// If we are here, we did not find parent mount. Something is wrong.
426 423
 	return "", "", fmt.Errorf("Could not find source mount of %s", source)
... ...
@@ -36,6 +36,17 @@ func SingleEntryFilter(mp string) FilterFunc {
36 36
 	}
37 37
 }
38 38
 
39
+// ParentsFilter returns all entries whose mount points
40
+// can be parents of a path specified, discarding others.
41
+// For example, given `/var/lib/docker/something`, entries
42
+// like `/var/lib/docker`, `/var` and `/` are returned.
43
+func ParentsFilter(path string) FilterFunc {
44
+	return func(m *Info) (bool, bool) {
45
+		skip := !strings.HasPrefix(path, m.Mountpoint)
46
+		return skip, false
47
+	}
48
+}
49
+
39 50
 // GetMounts retrieves a list of mounts for the current running process,
40 51
 // with an optional filter applied (use nil for no filter).
41 52
 func GetMounts(f FilterFunc) ([]*Info, error) {
... ...
@@ -499,4 +499,10 @@ func TestParseMountinfoFilters(t *testing.T) {
499 499
 	infos, err = parseInfoFile(r, PrefixFilter("nonexistent"))
500 500
 	assert.NilError(t, err)
501 501
 	assert.Equal(t, 0, len(infos))
502
+
503
+	r.Reset([]byte(fedoraMountinfo))
504
+	infos, err = parseInfoFile(r, ParentsFilter("/sys/fs/cgroup/cpu,cpuacct"))
505
+	assert.NilError(t, err)
506
+	// there should be 4 results returned: /sys/fs/cgroup/cpu,cpuacct /sys/fs/cgroup /sys /
507
+	assert.Equal(t, 4, len(infos))
502 508
 }