Sfoglia il codice sorgente

webdav: fix GET as PROPFIND if a prefix is defined

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
Nicola Murino 2 anni fa
parent
commit
8e86782d85
2 ha cambiato i file con 30 aggiunte e 0 eliminazioni
  1. 26 0
      internal/webdavd/internal_test.go
  2. 4 0
      internal/webdavd/server.go

+ 26 - 0
internal/webdavd/internal_test.go

@@ -34,6 +34,7 @@ import (
 	"github.com/eikenb/pipeat"
 	"github.com/sftpgo/sdk"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 
 	"github.com/drakkan/sftpgo/v2/internal/common"
 	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
@@ -636,6 +637,31 @@ func TestFileAccessErrors(t *testing.T) {
 	}
 }
 
+func TestCheckRequestMethodWithPrefix(t *testing.T) {
+	user := dataprovider.User{
+		BaseUser: sdk.BaseUser{
+			HomeDir: filepath.Clean(os.TempDir()),
+			Permissions: map[string][]string{
+				"/": {dataprovider.PermAny},
+			},
+		},
+	}
+	fs := vfs.NewOsFs("connID", user.HomeDir, "")
+	connection := &Connection{
+		BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, "", "", user),
+	}
+	server := webDavServer{
+		binding: Binding{
+			Prefix: "/dav",
+		},
+	}
+	req, err := http.NewRequest(http.MethodGet, "/../dav", nil)
+	require.NoError(t, err)
+	server.checkRequestMethod(context.Background(), req, connection)
+	require.Equal(t, "PROPFIND", req.Method)
+	require.Equal(t, "1", req.Header.Get("Depth"))
+}
+
 func TestContentType(t *testing.T) {
 	user := dataprovider.User{
 		BaseUser: sdk.BaseUser{

+ 4 - 0
internal/webdavd/server.go

@@ -26,6 +26,7 @@ import (
 	"path"
 	"path/filepath"
 	"runtime/debug"
+	"strings"
 	"time"
 
 	"github.com/drakkan/webdav"
@@ -137,6 +138,9 @@ func (s *webDavServer) checkRequestMethod(ctx context.Context, r *http.Request,
 	// see RFC4918, section 9.4
 	if r.Method == http.MethodGet || r.Method == http.MethodHead {
 		p := path.Clean(r.URL.Path)
+		if s.binding.Prefix != "" {
+			p = strings.TrimPrefix(p, s.binding.Prefix)
+		}
 		info, err := connection.Stat(ctx, p)
 		if err == nil && info.IsDir() {
 			if r.Method == http.MethodHead {