Przeglądaj źródła

Improve FollowLink to handle recursive link and be more strick

Guillaume J. Charmes 11 lat temu
rodzic
commit
8fd9633a6b
5 zmienionych plików z 51 dodań i 40 usunięć
  1. 1 1
      container.go
  2. 29 19
      utils/fs.go
  3. 19 20
      utils/fs_test.go
  4. 1 0
      utils/testdata/fs/b/h
  5. 1 0
      utils/testdata/fs/g

+ 1 - 1
container.go

@@ -827,7 +827,7 @@ func (container *Container) createVolumes() error {
 
 		// Create the mountpoint
 		volPath = path.Join(container.RootfsPath(), volPath)
-		rootVolPath, err := utils.FollowSymlink(volPath, container.RootfsPath())
+		rootVolPath, err := utils.FollowSymlinkInScope(volPath, container.RootfsPath())
 		if err != nil {
 			panic(err)
 		}

+ 29 - 19
utils/fs.go

@@ -1,6 +1,7 @@
 package utils
 
 import (
+	"fmt"
 	"os"
 	"path/filepath"
 	"strings"
@@ -38,43 +39,52 @@ func TreeSize(dir string) (size int64, err error) {
 // FollowSymlink will follow an existing link and scope it to the root
 // path provided.
 func FollowSymlinkInScope(link, root string) (string, error) {
-	prev := "."
+	prev := "/"
 
 	root, err := filepath.Abs(root)
 	if err != nil {
 		return "", err
 	}
-	root = filepath.Clean(root)
-	link, err := filepath.Abs(link)
+
+	link, err = filepath.Abs(link)
 	if err != nil {
 		return "", err
 	}
-	link = filepath.Clean(link)
+
+	if !strings.HasPrefix(filepath.Dir(link), root) {
+		return "", fmt.Errorf("%s is not within %s", link, root)
+	}
 
 	for _, p := range strings.Split(link, "/") {
 		prev = filepath.Join(prev, p)
 		prev = filepath.Clean(prev)
 
-		stat, err := os.Lstat(prev)
-		if err != nil {
-			if os.IsNotExist(err) {
-				continue
-			}
-			return "", err
-		}
-		if stat.Mode()&os.ModeSymlink == os.ModeSymlink {
-			dest, err := os.Readlink(prev)
+		for {
+			stat, err := os.Lstat(prev)
 			if err != nil {
+				if os.IsNotExist(err) {
+					break
+				}
 				return "", err
 			}
+			if stat.Mode()&os.ModeSymlink == os.ModeSymlink {
+				dest, err := os.Readlink(prev)
+				if err != nil {
+					return "", err
+				}
+
+				switch dest[0] {
+				case '/':
+					prev = filepath.Join(root, dest)
+				case '.':
+					prev, _ = filepath.Abs(prev)
 
-			switch dest[0] {
-			case '/':
-				prev = filepath.Join(root, dest)
-			case '.':
-				if prev = filepath.Clean(filepath.Join(filepath.Dir(prev), dest)); len(prev) < len(root) {
-					prev = filepath.Join(root, filepath.Base(dest))
+					if prev = filepath.Clean(filepath.Join(filepath.Dir(prev), dest)); len(prev) < len(root) {
+						prev = filepath.Join(root, filepath.Base(dest))
+					}
 				}
+			} else {
+				break
 			}
 		}
 	}

+ 19 - 20
utils/fs_test.go

@@ -1,15 +1,14 @@
 package utils
 
 import (
-	"os"
 	"path/filepath"
 	"testing"
 )
 
-func abs(p string) string {
+func abs(t *testing.T, p string) string {
 	o, err := filepath.Abs(p)
 	if err != nil {
-		panic(err)
+		t.Fatal(err)
 	}
 	return o
 }
@@ -17,36 +16,31 @@ func abs(p string) string {
 func TestFollowSymLinkNormal(t *testing.T) {
 	link := "testdata/fs/a/d/c/data"
 
-	rewrite, err := FollowSymlink(link, "test")
+	rewrite, err := FollowSymlinkInScope(link, "testdata")
 	if err != nil {
 		t.Fatal(err)
 	}
 
-	if expected := abs("test/b/c/data"); expected != rewrite {
+	if expected := abs(t, "testdata/b/c/data"); expected != rewrite {
 		t.Fatalf("Expected %s got %s", expected, rewrite)
 	}
 }
 
 func TestFollowSymLinkRandomString(t *testing.T) {
-	rewrite, err := FollowSymlink("toto", "test")
-	if err != nil {
-		t.Fatal(err)
-	}
-
-	if rewrite != "toto" {
-		t.Fatalf("Expected toto got %s", rewrite)
+	if _, err := FollowSymlinkInScope("toto", "testdata"); err == nil {
+		t.Fatal("Random string should fail but didn't")
 	}
 }
 
 func TestFollowSymLinkLastLink(t *testing.T) {
 	link := "testdata/fs/a/d"
 
-	rewrite, err := FollowSymlink(link, "test")
+	rewrite, err := FollowSymlinkInScope(link, "testdata")
 	if err != nil {
 		t.Fatal(err)
 	}
 
-	if expected := abs("test/b"); expected != rewrite {
+	if expected := abs(t, "testdata/b"); expected != rewrite {
 		t.Fatalf("Expected %s got %s", expected, rewrite)
 	}
 }
@@ -54,31 +48,36 @@ func TestFollowSymLinkLastLink(t *testing.T) {
 func TestFollowSymLinkRelativeLink(t *testing.T) {
 	link := "testdata/fs/a/e/c/data"
 
-	rewrite, err := FollowSymlink(link, "test")
+	rewrite, err := FollowSymlinkInScope(link, "testdata")
 	if err != nil {
 		t.Fatal(err)
 	}
 
-	if expected := abs("testdata/fs/a/e/c/data"); expected != rewrite {
+	if expected := abs(t, "testdata/fs/b/c/data"); expected != rewrite {
 		t.Fatalf("Expected %s got %s", expected, rewrite)
 	}
 }
 
 func TestFollowSymLinkRelativeLinkScope(t *testing.T) {
 	link := "testdata/fs/a/f"
-	pwd, err := os.Getwd()
+
+	rewrite, err := FollowSymlinkInScope(link, "testdata")
 	if err != nil {
 		t.Fatal(err)
 	}
 
-	root := filepath.Join(pwd, "testdata")
+	if expected := abs(t, "testdata/test"); expected != rewrite {
+		t.Fatalf("Expected %s got %s", expected, rewrite)
+	}
+
+	link = "testdata/fs/b/h"
 
-	rewrite, err := FollowSymlink(link, root)
+	rewrite, err = FollowSymlinkInScope(link, "testdata")
 	if err != nil {
 		t.Fatal(err)
 	}
 
-	if expected := abs("testdata/test"); expected != rewrite {
+	if expected := abs(t, "testdata/root"); expected != rewrite {
 		t.Fatalf("Expected %s got %s", expected, rewrite)
 	}
 }

+ 1 - 0
utils/testdata/fs/b/h

@@ -0,0 +1 @@
+../g

+ 1 - 0
utils/testdata/fs/g

@@ -0,0 +1 @@
+../../../../../../../../../../../../root