浏览代码

osfs: improve isSubDir check

Nicola Murino 4 年之前
父节点
当前提交
bf708cb8bc
共有 2 个文件被更改,包括 18 次插入1 次删除
  1. 9 0
      sftpd/sftpd_test.go
  2. 9 1
      vfs/osfs.go

+ 9 - 0
sftpd/sftpd_test.go

@@ -887,6 +887,9 @@ func TestEscapeHomeDir(t *testing.T) {
 	usePubKey := true
 	user, _, err := httpd.AddUser(getTestUser(usePubKey), http.StatusOK)
 	assert.NoError(t, err)
+	dirOutsideHome := filepath.Join(homeBasePath, defaultUsername+"1", "dir")
+	err = os.MkdirAll(dirOutsideHome, os.ModePerm)
+	assert.NoError(t, err)
 	client, err := getSftpClient(user, usePubKey)
 	if assert.NoError(t, err) {
 		defer client.Close()
@@ -899,6 +902,10 @@ func TestEscapeHomeDir(t *testing.T) {
 		assert.Error(t, err, "reading a symbolic link outside home dir should not succeeded")
 		err = os.Remove(linkPath)
 		assert.NoError(t, err)
+		err = os.Symlink(dirOutsideHome, linkPath)
+		assert.NoError(t, err)
+		_, err := client.ReadDir(testDir)
+		assert.Error(t, err, "reading a symbolic link outside home dir should not succeeded")
 		testFilePath := filepath.Join(homeBasePath, testFileName)
 		testFileSize := int64(65535)
 		err = createTestFile(testFilePath, testFileSize)
@@ -928,6 +935,8 @@ func TestEscapeHomeDir(t *testing.T) {
 	assert.NoError(t, err)
 	err = os.RemoveAll(user.GetHomeDir())
 	assert.NoError(t, err)
+	err = os.RemoveAll(filepath.Join(homeBasePath, defaultUsername+"1"))
+	assert.NoError(t, err)
 }
 
 func TestHomeSpecialChars(t *testing.T) {

+ 9 - 1
vfs/osfs.go

@@ -413,7 +413,15 @@ func (fs *OsFs) isSubDir(sub, rootPath string) error {
 		fsLog(fs, logger.LevelWarn, "invalid root path %#v: %v", rootPath, err)
 		return err
 	}
-	if !strings.HasPrefix(sub, parent) {
+	if parent == sub {
+		return nil
+	}
+	if len(sub) < len(parent) {
+		err = fmt.Errorf("path %#v is not inside: %#v", sub, parent)
+		fsLog(fs, logger.LevelWarn, "error: %v ", err)
+		return err
+	}
+	if !strings.HasPrefix(sub, parent+string(os.PathSeparator)) {
 		err = fmt.Errorf("path %#v is not inside: %#v", sub, parent)
 		fsLog(fs, logger.LevelWarn, "error: %v ", err)
 		return err