Procházet zdrojové kódy

pkg/mount: implement/use filter for mountinfo parsing

Functions `GetMounts()` and `parseMountTable()` return all the entries
as read and parsed from /proc/self/mountinfo. In many cases the caller
is only interested only one or a few entries, not all of them.

One good example is `Mounted()` function, which looks for a specific
entry only. Another example is `RecursiveUnmount()` which is only
interested in mount under a specific path.

This commit adds `filter` argument to `GetMounts()` to implement
two things:
 1. filter out entries a caller is not interested in
 2. stop processing if a caller is found what it wanted

`nil` can be passed to get a backward-compatible behavior, i.e. return
all the entries.

A few filters are implemented:
 - `PrefixFilter`: filters out all entries not under `prefix`
 - `SingleEntryFilter`: looks for a specific entry

Finally, `Mounted()` is modified to use `SingleEntryFilter()`, and
`RecursiveUnmount()` is using `PrefixFilter()`.

Unit tests are added to check filters are working.

[v2: ditch NoFilter, use nil]
[v3: ditch GetMountsFiltered()]
[v4: add unit test for filters]
[v5: switch to gotestyourself]

Signed-off-by: Kir Kolyshkin <kolyshkin@gmail.com>
Kir Kolyshkin před 7 roky
rodič
revize
bb934c6aca

+ 1 - 1
daemon/daemon_linux.go

@@ -74,7 +74,7 @@ func (daemon *Daemon) cleanupMounts() error {
 		return err
 	}
 
-	infos, err := mount.GetMounts()
+	infos, err := mount.GetMounts(nil)
 	if err != nil {
 		return errors.Wrap(err, "error reading mount table for cleanup")
 	}

+ 1 - 1
daemon/graphdriver/zfs/zfs.go

@@ -145,7 +145,7 @@ func lookupZfsDataset(rootdir string) (string, error) {
 	}
 	wantedDev := stat.Dev
 
-	mounts, err := mount.GetMounts()
+	mounts, err := mount.GetMounts(nil)
 	if err != nil {
 		return "", err
 	}

+ 1 - 1
daemon/oci_linux.go

@@ -398,7 +398,7 @@ func getSourceMount(source string) (string, string, error) {
 		return "", "", err
 	}
 
-	mountinfos, err := mount.GetMounts()
+	mountinfos, err := mount.GetMounts(nil)
 	if err != nil {
 		return "", "", err
 	}

+ 1 - 1
integration-cli/docker_cli_daemon_plugins_test.go

@@ -260,7 +260,7 @@ func (s *DockerDaemonSuite) TestPluginVolumeRemoveOnRestart(c *check.C) {
 }
 
 func existsMountpointWithPrefix(mountpointPrefix string) (bool, error) {
-	mounts, err := mount.GetMounts()
+	mounts, err := mount.GetMounts(nil)
 	if err != nil {
 		return false, err
 	}

+ 34 - 15
pkg/mount/mount.go

@@ -9,26 +9,48 @@ import (
 	"github.com/sirupsen/logrus"
 )
 
-// GetMounts retrieves a list of mounts for the current running process.
-func GetMounts() ([]*Info, error) {
-	return parseMountTable()
+// FilterFunc is a type defining a callback function
+// to filter out unwanted entries. It takes a pointer
+// to an Info struct (not fully populated, currently
+// only Mountpoint is filled in), and returns two booleans:
+//  - skip: true if the entry should be skipped
+//  - stop: true if parsing should be stopped after the entry
+type FilterFunc func(*Info) (skip, stop bool)
+
+// PrefixFilter discards all entries whose mount points
+// do not start with a prefix specified
+func PrefixFilter(prefix string) FilterFunc {
+	return func(m *Info) (bool, bool) {
+		skip := !strings.HasPrefix(m.Mountpoint, prefix)
+		return skip, false
+	}
+}
+
+// SingleEntryFilter looks for a specific entry
+func SingleEntryFilter(mp string) FilterFunc {
+	return func(m *Info) (bool, bool) {
+		if m.Mountpoint == mp {
+			return false, true // don't skip, stop now
+		}
+		return true, false // skip, keep going
+	}
+}
+
+// GetMounts retrieves a list of mounts for the current running process,
+// with an optional filter applied (use nil for no filter).
+func GetMounts(f FilterFunc) ([]*Info, error) {
+	return parseMountTable(f)
 }
 
 // Mounted determines if a specified mountpoint has been mounted.
 // On Linux it looks at /proc/self/mountinfo.
 func Mounted(mountpoint string) (bool, error) {
-	entries, err := parseMountTable()
+	entries, err := GetMounts(SingleEntryFilter(mountpoint))
 	if err != nil {
 		return false, err
 	}
 
-	// Search the table for the mountpoint
-	for _, e := range entries {
-		if e.Mountpoint == mountpoint {
-			return true, nil
-		}
-	}
-	return false, nil
+	return len(entries) > 0, nil
 }
 
 // Mount will mount filesystem according to the specified configuration, on the
@@ -66,7 +88,7 @@ func Unmount(target string) error {
 // RecursiveUnmount unmounts the target and all mounts underneath, starting with
 // the deepsest mount first.
 func RecursiveUnmount(target string) error {
-	mounts, err := GetMounts()
+	mounts, err := parseMountTable(PrefixFilter(target))
 	if err != nil {
 		return err
 	}
@@ -77,9 +99,6 @@ func RecursiveUnmount(target string) error {
 	})
 
 	for i, m := range mounts {
-		if !strings.HasPrefix(m.Mountpoint, target) {
-			continue
-		}
 		logrus.Debugf("Trying to unmount %s", m.Mountpoint)
 		err = unmount(m.Mountpoint, mntDetach)
 		if err != nil {

+ 1 - 1
pkg/mount/mount_unix_test.go

@@ -129,7 +129,7 @@ func TestMountReadonly(t *testing.T) {
 }
 
 func TestGetMounts(t *testing.T) {
-	mounts, err := GetMounts()
+	mounts, err := GetMounts(nil)
 	if err != nil {
 		t.Fatal(err)
 	}

+ 1 - 1
pkg/mount/mounter_linux_test.go

@@ -121,7 +121,7 @@ func ensureUnmount(t *testing.T, mnt string) {
 
 // validateMount checks that mnt has the given options
 func validateMount(t *testing.T, mnt string, opts, optional, vfs string) {
-	info, err := GetMounts()
+	info, err := GetMounts(nil)
 	if err != nil {
 		t.Fatal(err)
 	}

+ 15 - 1
pkg/mount/mountinfo_freebsd.go

@@ -15,7 +15,7 @@ import (
 
 // Parse /proc/self/mountinfo because comparing Dev and ino does not work from
 // bind mounts.
-func parseMountTable() ([]*Info, error) {
+func parseMountTable(filter FilterFunc) ([]*Info, error) {
 	var rawEntries *C.struct_statfs
 
 	count := int(C.getmntinfo(&rawEntries, C.MNT_WAIT))
@@ -32,10 +32,24 @@ func parseMountTable() ([]*Info, error) {
 	var out []*Info
 	for _, entry := range entries {
 		var mountinfo Info
+		var skip, stop bool
 		mountinfo.Mountpoint = C.GoString(&entry.f_mntonname[0])
+
+		if filter != nil {
+			// filter out entries we're not interested in
+			skip, stop = filter(p)
+			if skip {
+				continue
+			}
+		}
+
 		mountinfo.Source = C.GoString(&entry.f_mntfromname[0])
 		mountinfo.Fstype = C.GoString(&entry.f_fstypename[0])
+
 		out = append(out, &mountinfo)
+		if stop {
+			break
+		}
 	}
 	return out, nil
 }

+ 15 - 4
pkg/mount/mountinfo_linux.go

@@ -28,17 +28,17 @@ const (
 
 // Parse /proc/self/mountinfo because comparing Dev and ino does not work from
 // bind mounts
-func parseMountTable() ([]*Info, error) {
+func parseMountTable(filter FilterFunc) ([]*Info, error) {
 	f, err := os.Open("/proc/self/mountinfo")
 	if err != nil {
 		return nil, err
 	}
 	defer f.Close()
 
-	return parseInfoFile(f)
+	return parseInfoFile(f, filter)
 }
 
-func parseInfoFile(r io.Reader) ([]*Info, error) {
+func parseInfoFile(r io.Reader, filter FilterFunc) ([]*Info, error) {
 	var (
 		s   = bufio.NewScanner(r)
 		out = []*Info{}
@@ -53,6 +53,7 @@ func parseInfoFile(r io.Reader) ([]*Info, error) {
 			p              = &Info{}
 			text           = s.Text()
 			optionalFields string
+			skip, stop     bool
 		)
 
 		if _, err := fmt.Sscanf(text, mountinfoFormat,
@@ -60,6 +61,13 @@ func parseInfoFile(r io.Reader) ([]*Info, error) {
 			&p.Root, &p.Mountpoint, &p.Opts, &optionalFields); err != nil {
 			return nil, fmt.Errorf("Scanning '%s' failed: %s", text, err)
 		}
+		if filter != nil {
+			// filter out entries we're not interested in
+			skip, stop = filter(p)
+			if skip {
+				continue
+			}
+		}
 		// Safe as mountinfo encodes mountpoints with spaces as \040.
 		index := strings.Index(text, " - ")
 		postSeparatorFields := strings.Fields(text[index+3:])
@@ -75,6 +83,9 @@ func parseInfoFile(r io.Reader) ([]*Info, error) {
 		p.Source = postSeparatorFields[1]
 		p.VfsOpts = strings.Join(postSeparatorFields[2:], " ")
 		out = append(out, p)
+		if stop {
+			break
+		}
 	}
 	return out, nil
 }
@@ -89,5 +100,5 @@ func PidMountInfo(pid int) ([]*Info, error) {
 	}
 	defer f.Close()
 
-	return parseInfoFile(f)
+	return parseInfoFile(f, nil)
 }

+ 30 - 4
pkg/mount/mountinfo_linux_test.go

@@ -5,6 +5,8 @@ package mount // import "github.com/docker/docker/pkg/mount"
 import (
 	"bytes"
 	"testing"
+
+	"github.com/gotestyourself/gotestyourself/assert"
 )
 
 const (
@@ -424,7 +426,7 @@ const (
 
 func TestParseFedoraMountinfo(t *testing.T) {
 	r := bytes.NewBuffer([]byte(fedoraMountinfo))
-	_, err := parseInfoFile(r)
+	_, err := parseInfoFile(r, nil)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -432,7 +434,7 @@ func TestParseFedoraMountinfo(t *testing.T) {
 
 func TestParseUbuntuMountinfo(t *testing.T) {
 	r := bytes.NewBuffer([]byte(ubuntuMountInfo))
-	_, err := parseInfoFile(r)
+	_, err := parseInfoFile(r, nil)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -440,7 +442,7 @@ func TestParseUbuntuMountinfo(t *testing.T) {
 
 func TestParseGentooMountinfo(t *testing.T) {
 	r := bytes.NewBuffer([]byte(gentooMountinfo))
-	_, err := parseInfoFile(r)
+	_, err := parseInfoFile(r, nil)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -448,7 +450,7 @@ func TestParseGentooMountinfo(t *testing.T) {
 
 func TestParseFedoraMountinfoFields(t *testing.T) {
 	r := bytes.NewBuffer([]byte(fedoraMountinfo))
-	infos, err := parseInfoFile(r)
+	infos, err := parseInfoFile(r, nil)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -474,3 +476,27 @@ func TestParseFedoraMountinfoFields(t *testing.T) {
 		t.Fatalf("expected %#v, got %#v", mi, infos[0])
 	}
 }
+
+func TestParseMountinfoFilters(t *testing.T) {
+	r := bytes.NewReader([]byte(fedoraMountinfo))
+
+	infos, err := parseInfoFile(r, SingleEntryFilter("/sys/fs/cgroup"))
+	assert.NilError(t, err)
+	assert.Equal(t, 1, len(infos))
+
+	r.Reset([]byte(fedoraMountinfo))
+	infos, err = parseInfoFile(r, SingleEntryFilter("nonexistent"))
+	assert.NilError(t, err)
+	assert.Equal(t, 0, len(infos))
+
+	r.Reset([]byte(fedoraMountinfo))
+	infos, err = parseInfoFile(r, PrefixFilter("/sys"))
+	assert.NilError(t, err)
+	// there are 18 entries starting with /sys in fedoraMountinfo
+	assert.Equal(t, 18, len(infos))
+
+	r.Reset([]byte(fedoraMountinfo))
+	infos, err = parseInfoFile(r, PrefixFilter("nonexistent"))
+	assert.NilError(t, err)
+	assert.Equal(t, 0, len(infos))
+}

+ 1 - 1
pkg/mount/mountinfo_unsupported.go

@@ -7,6 +7,6 @@ import (
 	"runtime"
 )
 
-func parseMountTable() ([]*Info, error) {
+func parseMountTable(f FilterFunc) ([]*Info, error) {
 	return nil, fmt.Errorf("mount.parseMountTable is not implemented on %s/%s", runtime.GOOS, runtime.GOARCH)
 }

+ 1 - 1
pkg/mount/mountinfo_windows.go

@@ -1,6 +1,6 @@
 package mount // import "github.com/docker/docker/pkg/mount"
 
-func parseMountTable() ([]*Info, error) {
+func parseMountTable(f FilterFunc) ([]*Info, error) {
 	// Do NOT return an error!
 	return nil, nil
 }

+ 1 - 1
volume/local/local.go

@@ -66,7 +66,7 @@ func New(scope string, rootIDs idtools.IDPair) (*Root, error) {
 		return nil, err
 	}
 
-	mountInfos, err := mount.GetMounts()
+	mountInfos, err := mount.GetMounts(nil)
 	if err != nil {
 		logrus.Debugf("error looking up mounts for local volume cleanup: %v", err)
 	}

+ 1 - 1
volume/local/local_test.go

@@ -215,7 +215,7 @@ func TestCreateWithOpts(t *testing.T) {
 		}
 	}()
 
-	mountInfos, err := mount.GetMounts()
+	mountInfos, err := mount.GetMounts(nil)
 	if err != nil {
 		t.Fatal(err)
 	}