소스 검색

move stat to base connection and differentiate between Stat and Lstat

we will use Lstat once it will be exposed in pkg/sftp
Nicola Murino 4 년 전
부모
커밋
2746c0b0f1
9개의 변경된 파일66개의 추가작업 그리고 14개의 파일을 삭제
  1. 8 0
      common/connection.go
  2. 32 0
      common/connection_test.go
  3. 5 3
      common/transfer_test.go
  4. 2 2
      ftpd/handler.go
  5. 1 2
      go.mod
  6. 3 2
      go.sum
  7. 3 3
      sftpd/handler.go
  8. 10 0
      sftpd/sftpd_test.go
  9. 2 2
      webdavd/handler.go

+ 8 - 0
common/connection.go

@@ -440,6 +440,14 @@ func (c *BaseConnection) getPathForSetStatPerms(fsPath, virtualPath string) stri
 	return pathForPerms
 }
 
+// DoStat execute a Stat if mode = 0, Lstat if mode = 1
+func (c *BaseConnection) DoStat(fsPath string, mode int) (os.FileInfo, error) {
+	if mode == 1 {
+		return c.Fs.Lstat(c.getRealFsPath(fsPath))
+	}
+	return c.Fs.Stat(c.getRealFsPath(fsPath))
+}
+
 // SetStat set StatAttributes for the specified fsPath
 func (c *BaseConnection) SetStat(fsPath, virtualPath string, attributes *StatAttributes) error {
 	if Config.SetstatMode == 1 {

+ 32 - 0
common/connection_test.go

@@ -11,6 +11,7 @@ import (
 
 	"github.com/pkg/sftp"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 
 	"github.com/drakkan/sftpgo/dataprovider"
 	"github.com/drakkan/sftpgo/vfs"
@@ -434,6 +435,37 @@ func TestCreateSymlink(t *testing.T) {
 	assert.NoError(t, err)
 }
 
+func TestDoStat(t *testing.T) {
+	testFile := filepath.Join(os.TempDir(), "afile.txt")
+	fs := vfs.NewOsFs("123", os.TempDir(), nil)
+	u := dataprovider.User{
+		Username: "user",
+		HomeDir:  os.TempDir(),
+	}
+	u.Permissions = make(map[string][]string)
+	u.Permissions["/"] = []string{dataprovider.PermAny}
+	err := ioutil.WriteFile(testFile, []byte("data"), os.ModePerm)
+	require.NoError(t, err)
+	err = os.Symlink(testFile, testFile+".sym")
+	require.NoError(t, err)
+	conn := NewBaseConnection(fs.ConnectionID(), ProtocolSFTP, u, fs)
+	infoStat, err := conn.DoStat(testFile+".sym", 0)
+	if assert.NoError(t, err) {
+		assert.Equal(t, int64(4), infoStat.Size())
+	}
+	infoLstat, err := conn.DoStat(testFile+".sym", 1)
+	if assert.NoError(t, err) {
+		assert.NotEqual(t, int64(4), infoLstat.Size())
+	}
+	assert.False(t, os.SameFile(infoStat, infoLstat))
+
+	err = os.Remove(testFile)
+	assert.NoError(t, err)
+	err = os.Remove(testFile + ".sym")
+	assert.NoError(t, err)
+	assert.Len(t, conn.GetTransfers(), 0)
+}
+
 func TestSetStat(t *testing.T) {
 	oldSetStatMode := Config.SetstatMode
 	Config.SetstatMode = 1

+ 5 - 3
common/transfer_test.go

@@ -9,6 +9,7 @@ import (
 	"time"
 
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 
 	"github.com/drakkan/sftpgo/dataprovider"
 	"github.com/drakkan/sftpgo/vfs"
@@ -93,9 +94,7 @@ func TestRealPath(t *testing.T) {
 	u.Permissions = make(map[string][]string)
 	u.Permissions["/"] = []string{dataprovider.PermAny}
 	file, err := os.Create(testFile)
-	if !assert.NoError(t, err) {
-		assert.FailNow(t, "unable to open test file")
-	}
+	require.NoError(t, err)
 	conn := NewBaseConnection(fs.ConnectionID(), ProtocolSFTP, u, fs)
 	transfer := NewBaseTransfer(file, conn, nil, testFile, "/transfer_test_file", TransferUpload, 0, 0, 0, true, fs)
 	rPath := transfer.GetRealFsPath(testFile)
@@ -111,6 +110,9 @@ func TestRealPath(t *testing.T) {
 	assert.Equal(t, testFile, rPath)
 	rPath = transfer.GetRealFsPath("")
 	assert.Empty(t, rPath)
+	err = os.Remove(testFile)
+	assert.NoError(t, err)
+	assert.Len(t, conn.GetTransfers(), 0)
 }
 
 func TestTruncate(t *testing.T) {

+ 2 - 2
ftpd/handler.go

@@ -148,9 +148,9 @@ func (c *Connection) Stat(name string) (os.FileInfo, error) {
 	if err != nil {
 		return nil, c.GetFsError(err)
 	}
-	fi, err := c.Fs.Stat(p)
+	fi, err := c.DoStat(p, 0)
 	if err != nil {
-		c.Log(logger.LevelWarn, "error running stat on path %#v: %+v", p, err)
+		c.Log(logger.LevelDebug, "error running stat on path %#v: %+v", p, err)
 		return nil, c.GetFsError(err)
 	}
 	return fi, nil

+ 1 - 2
go.mod

@@ -24,7 +24,7 @@ require (
 	github.com/otiai10/copy v1.2.0
 	github.com/pelletier/go-toml v1.8.0 // indirect
 	github.com/pires/go-proxyproto v0.1.3
-	github.com/pkg/sftp v1.11.1-0.20200819110714-3ee8d0ba91c0
+	github.com/pkg/sftp v1.11.1-0.20200825160622-06ab92ee3917
 	github.com/prometheus/client_golang v1.7.1
 	github.com/prometheus/common v0.13.0 // indirect
 	github.com/rs/cors v1.7.1-0.20200626170627-8b4a00bd362b
@@ -50,7 +50,6 @@ require (
 
 replace (
 	github.com/jlaffaye/ftp => github.com/drakkan/ftp v0.0.0-20200730125632-b21eac28818c
-	github.com/pkg/sftp => github.com/drakkan/sftp v0.0.0-20200824132209-4da3253ee1d6
 	golang.org/x/crypto => github.com/drakkan/crypto v0.0.0-20200824205004-9f5ce89c1796
 	golang.org/x/net => github.com/drakkan/net v0.0.0-20200824204746-8b31adf087bf
 )

+ 3 - 2
go.sum

@@ -109,8 +109,6 @@ github.com/drakkan/ftp v0.0.0-20200730125632-b21eac28818c h1:QSXIWohSNn0negBVSKE
 github.com/drakkan/ftp v0.0.0-20200730125632-b21eac28818c/go.mod h1:2lmrmq866uF2tnje75wQHzmPXhmSWUt7Gyx2vgK1RCU=
 github.com/drakkan/net v0.0.0-20200824204746-8b31adf087bf h1:MbeUXErR+xQ1Yvk+E6wYBKvgK8nvDiXk00jNEyDRvE8=
 github.com/drakkan/net v0.0.0-20200824204746-8b31adf087bf/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
-github.com/drakkan/sftp v0.0.0-20200824132209-4da3253ee1d6 h1:2odjq46RgD0/qTzAiX6Rxxq+7jAzu6gGoHBHcgIS0og=
-github.com/drakkan/sftp v0.0.0-20200824132209-4da3253ee1d6/go.mod h1:i24A96cQ6ZvWut9G/Uv3LvC4u3VebGsBR5JFvPyChLc=
 github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
 github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs=
 github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU=
@@ -366,6 +364,9 @@ github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE
 github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
 github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
 github.com/pkg/profile v1.2.1/go.mod h1:hJw3o1OdXxsrSjjVksARp5W95eeEaEfptyVZyv6JUPA=
+github.com/pkg/sftp v1.10.1/go.mod h1:lYOWFsE0bwd1+KfKJaKeuokY15vzFx25BLbzYYoAxZI=
+github.com/pkg/sftp v1.11.1-0.20200825160622-06ab92ee3917 h1:NO+7wv5cAXYMhTpVX0e97zH349VVp8c2dB0d/SXfmkg=
+github.com/pkg/sftp v1.11.1-0.20200825160622-06ab92ee3917/go.mod h1:i24A96cQ6ZvWut9G/Uv3LvC4u3VebGsBR5JFvPyChLc=
 github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
 github.com/posener/complete v1.1.1/go.mod h1:em0nMJCgc9GFtwrmVmEMR/ZL6WyhyjMBndrE9hABlRI=

+ 3 - 3
sftpd/handler.go

@@ -188,9 +188,9 @@ func (c *Connection) Filelist(request *sftp.Request) (sftp.ListerAt, error) {
 			return nil, sftp.ErrSSHFxPermissionDenied
 		}
 
-		s, err := c.Fs.Stat(p)
+		s, err := c.DoStat(p, 0)
 		if err != nil {
-			c.Log(logger.LevelWarn, "error running stat on path %#v: %+v", p, err)
+			c.Log(logger.LevelDebug, "error running stat on path %#v: %+v", p, err)
 			return nil, c.GetFsError(err)
 		}
 
@@ -202,7 +202,7 @@ func (c *Connection) Filelist(request *sftp.Request) (sftp.ListerAt, error) {
 
 		s, err := c.Fs.Readlink(p)
 		if err != nil {
-			c.Log(logger.LevelWarn, "error running readlink on path %#v: %+v", p, err)
+			c.Log(logger.LevelDebug, "error running readlink on path %#v: %+v", p, err)
 			return nil, c.GetFsError(err)
 		}
 

+ 10 - 0
sftpd/sftpd_test.go

@@ -2662,6 +2662,8 @@ func TestTruncateQuotaLimits(t *testing.T) {
 			assert.NoError(t, err)
 			assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles)
 			assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize)
+			_, err = f.Seek(expectedQuotaSize, io.SeekStart)
+			assert.NoError(t, err)
 			n, err = f.Write(data)
 			assert.NoError(t, err)
 			assert.Equal(t, len(data), n)
@@ -2672,6 +2674,8 @@ func TestTruncateQuotaLimits(t *testing.T) {
 			assert.NoError(t, err)
 			assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles)
 			assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize)
+			_, err = f.Seek(expectedQuotaSize, io.SeekStart)
+			assert.NoError(t, err)
 			n, err = f.Write(data)
 			assert.NoError(t, err)
 			assert.Equal(t, len(data), n)
@@ -2723,6 +2727,8 @@ func TestTruncateQuotaLimits(t *testing.T) {
 			assert.NoError(t, err)
 			assert.Equal(t, 1, user.UsedQuotaFiles)
 			assert.Equal(t, int64(11), user.UsedQuotaSize)
+			_, err = f.Seek(int64(11), io.SeekStart)
+			assert.NoError(t, err)
 			n, err = f.Write(data)
 			assert.NoError(t, err)
 			assert.Equal(t, len(data), n)
@@ -2732,6 +2738,8 @@ func TestTruncateQuotaLimits(t *testing.T) {
 			assert.NoError(t, err)
 			assert.Equal(t, 1, user.UsedQuotaFiles)
 			assert.Equal(t, int64(5), user.UsedQuotaSize)
+			_, err = f.Seek(int64(5), io.SeekStart)
+			assert.NoError(t, err)
 			n, err = f.Write(data)
 			assert.NoError(t, err)
 			assert.Equal(t, len(data), n)
@@ -2741,6 +2749,8 @@ func TestTruncateQuotaLimits(t *testing.T) {
 			assert.NoError(t, err)
 			assert.Equal(t, 1, user.UsedQuotaFiles)
 			assert.Equal(t, int64(12), user.UsedQuotaSize)
+			_, err = f.Seek(int64(12), io.SeekStart)
+			assert.NoError(t, err)
 			_, err = f.Write(data)
 			if assert.Error(t, err) {
 				assert.Contains(t, err.Error(), common.ErrQuotaExceeded.Error())

+ 2 - 2
webdavd/handler.go

@@ -105,9 +105,9 @@ func (c *Connection) Stat(ctx context.Context, name string) (os.FileInfo, error)
 	if err != nil {
 		return nil, c.GetFsError(err)
 	}
-	fi, err := c.Fs.Stat(p)
+	fi, err := c.DoStat(p, 0)
 	if err != nil {
-		c.Log(logger.LevelWarn, "error running stat on path %#v: %+v", p, err)
+		c.Log(logger.LevelDebug, "error running stat on path %#v: %+v", p, err)
 		return nil, c.GetFsError(err)
 	}
 	return fi, err