Procházet zdrojové kódy

getSourceMount(): simplify

The flow of getSourceMount was:
 1 get all entries from /proc/self/mountinfo
 2 do a linear search for the `source` directory
 3 if found, return its data
 4 get the parent directory of `source`, goto 2

The repeated linear search through the whole mountinfo (which can have
thousands of records) is inefficient. Instead, let's just

 1 collect all the relevant records (only those mount points
   that can be a parent of `source`)
 2 find the record with the longest mountpath, return its data

This was tested manually with something like

```go
func TestGetSourceMount(t *testing.T) {
	mnt, flags, err := getSourceMount("/sys/devices/msr/")
	assert.NoError(t, err)
	t.Logf("mnt: %v, flags: %v", mnt, flags)
}
```

...but it relies on having a specific mount points on the system
being used for testing.

[v2: add unit tests for ParentsFilter]

Signed-off-by: Kir Kolyshkin <kolyshkin@gmail.com>
Kir Kolyshkin před 7 roky
rodič
revize
871c957242
3 změnil soubory, kde provedl 30 přidání a 25 odebrání
  1. 13 25
      daemon/oci_linux.go
  2. 11 0
      pkg/mount/mount.go
  3. 6 0
      pkg/mount/mountinfo_linux_test.go

+ 13 - 25
daemon/oci_linux.go

@@ -380,15 +380,6 @@ func specMapping(s []idtools.IDMap) []specs.LinuxIDMapping {
 	return ids
 }
 
-func getMountInfo(mountinfo []*mount.Info, dir string) *mount.Info {
-	for _, m := range mountinfo {
-		if m.Mountpoint == dir {
-			return m
-		}
-	}
-	return nil
-}
-
 // Get the source mount point of directory passed in as argument. Also return
 // optional fields.
 func getSourceMount(source string) (string, string, error) {
@@ -398,29 +389,26 @@ func getSourceMount(source string) (string, string, error) {
 		return "", "", err
 	}
 
-	mountinfos, err := mount.GetMounts(nil)
+	mi, err := mount.GetMounts(mount.ParentsFilter(sourcePath))
 	if err != nil {
 		return "", "", err
 	}
-
-	mountinfo := getMountInfo(mountinfos, sourcePath)
-	if mountinfo != nil {
-		return sourcePath, mountinfo.Optional, nil
+	if len(mi) < 1 {
+		return "", "", fmt.Errorf("Can't find mount point of %s", source)
 	}
 
-	path := sourcePath
-	for {
-		path = filepath.Dir(path)
-
-		mountinfo = getMountInfo(mountinfos, path)
-		if mountinfo != nil {
-			return path, mountinfo.Optional, nil
-		}
-
-		if path == "/" {
-			break
+	// find the longest mount point
+	var idx, maxlen int
+	for i := range mi {
+		if len(mi[i].Mountpoint) > maxlen {
+			maxlen = len(mi[i].Mountpoint)
+			idx = i
 		}
 	}
+	// and return it unless it's "/"
+	if mi[idx].Mountpoint != "/" {
+		return mi[idx].Mountpoint, mi[idx].Optional, nil
+	}
 
 	// If we are here, we did not find parent mount. Something is wrong.
 	return "", "", fmt.Errorf("Could not find source mount of %s", source)

+ 11 - 0
pkg/mount/mount.go

@@ -36,6 +36,17 @@ func SingleEntryFilter(mp string) FilterFunc {
 	}
 }
 
+// ParentsFilter returns all entries whose mount points
+// can be parents of a path specified, discarding others.
+// For example, given `/var/lib/docker/something`, entries
+// like `/var/lib/docker`, `/var` and `/` are returned.
+func ParentsFilter(path string) FilterFunc {
+	return func(m *Info) (bool, bool) {
+		skip := !strings.HasPrefix(path, m.Mountpoint)
+		return skip, false
+	}
+}
+
 // GetMounts retrieves a list of mounts for the current running process,
 // with an optional filter applied (use nil for no filter).
 func GetMounts(f FilterFunc) ([]*Info, error) {

+ 6 - 0
pkg/mount/mountinfo_linux_test.go

@@ -499,4 +499,10 @@ func TestParseMountinfoFilters(t *testing.T) {
 	infos, err = parseInfoFile(r, PrefixFilter("nonexistent"))
 	assert.NilError(t, err)
 	assert.Equal(t, 0, len(infos))
+
+	r.Reset([]byte(fedoraMountinfo))
+	infos, err = parseInfoFile(r, ParentsFilter("/sys/fs/cgroup/cpu,cpuacct"))
+	assert.NilError(t, err)
+	// there should be 4 results returned: /sys/fs/cgroup/cpu,cpuacct /sys/fs/cgroup /sys /
+	assert.Equal(t, 4, len(infos))
 }