add DirLister interface

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
Nicola Murino 2024-02-15 20:53:56 +01:00
parent c60eb050ef
commit 1ff55bbfa7
No known key found for this signature in database
GPG key ID: 935D2952DEC4EECF
35 changed files with 1362 additions and 669 deletions

11
go.mod
View file

@ -21,7 +21,7 @@ require (
github.com/bmatcuk/doublestar/v4 v4.6.1 github.com/bmatcuk/doublestar/v4 v4.6.1
github.com/cockroachdb/cockroach-go/v2 v2.3.6 github.com/cockroachdb/cockroach-go/v2 v2.3.6
github.com/coreos/go-oidc/v3 v3.9.0 github.com/coreos/go-oidc/v3 v3.9.0
github.com/drakkan/webdav v0.0.0-20230227175313-32996838bcd8 github.com/drakkan/webdav v0.0.0-20240212101318-94e905cb9adb
github.com/eikenb/pipeat v0.0.0-20210730190139-06b3e6902001 github.com/eikenb/pipeat v0.0.0-20210730190139-06b3e6902001
github.com/fclairamb/ftpserverlib v0.22.0 github.com/fclairamb/ftpserverlib v0.22.0
github.com/fclairamb/go-log v0.4.1 github.com/fclairamb/go-log v0.4.1
@ -71,7 +71,7 @@ require (
golang.org/x/crypto v0.18.0 golang.org/x/crypto v0.18.0
golang.org/x/net v0.20.0 golang.org/x/net v0.20.0
golang.org/x/oauth2 v0.16.0 golang.org/x/oauth2 v0.16.0
golang.org/x/sys v0.16.0 golang.org/x/sys v0.17.0
golang.org/x/term v0.16.0 golang.org/x/term v0.16.0
golang.org/x/time v0.5.0 golang.org/x/time v0.5.0
google.golang.org/api v0.161.0 google.golang.org/api v0.161.0
@ -118,7 +118,9 @@ require (
github.com/google/s2a-go v0.1.7 // indirect github.com/google/s2a-go v0.1.7 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
github.com/googleapis/gax-go/v2 v2.12.0 // indirect github.com/googleapis/gax-go/v2 v2.12.0 // indirect
github.com/hashicorp/errwrap v1.0.0 // indirect
github.com/hashicorp/go-cleanhttp v0.5.2 // indirect github.com/hashicorp/go-cleanhttp v0.5.2 // indirect
github.com/hashicorp/go-multierror v1.1.1 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect
github.com/hashicorp/yamux v0.1.1 // indirect github.com/hashicorp/yamux v0.1.1 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect
@ -181,8 +183,9 @@ require (
) )
replace ( replace (
github.com/fclairamb/ftpserverlib => github.com/drakkan/ftpserverlib v0.0.0-20230820193955-e7243edeb89b github.com/fclairamb/ftpserverlib => github.com/drakkan/ftpserverlib v0.0.0-20240212100826-a241365cb085
github.com/jlaffaye/ftp => github.com/drakkan/ftp v0.0.0-20201114075148-9b9adce499a9 github.com/jlaffaye/ftp => github.com/drakkan/ftp v0.0.0-20240210102745-f1ffc43f78d2
github.com/pkg/sftp => github.com/drakkan/sftp v0.0.0-20240214104840-fbb0b8bdb30c
github.com/robfig/cron/v3 => github.com/drakkan/cron/v3 v3.0.0-20230222140221-217a1e4d96c0 github.com/robfig/cron/v3 => github.com/drakkan/cron/v3 v3.0.0-20230222140221-217a1e4d96c0
golang.org/x/crypto => github.com/drakkan/crypto v0.0.0-20231218163632-74b52eafd2c0 golang.org/x/crypto => github.com/drakkan/crypto v0.0.0-20231218163632-74b52eafd2c0
) )

23
go.sum
View file

@ -113,12 +113,14 @@ github.com/drakkan/cron/v3 v3.0.0-20230222140221-217a1e4d96c0 h1:EW9gIJRmt9lzk66
github.com/drakkan/cron/v3 v3.0.0-20230222140221-217a1e4d96c0/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= github.com/drakkan/cron/v3 v3.0.0-20230222140221-217a1e4d96c0/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
github.com/drakkan/crypto v0.0.0-20231218163632-74b52eafd2c0 h1:Yel8NcrK4jg+biIcTxnszKh0eIpF2Vj25XEygQcTweI= github.com/drakkan/crypto v0.0.0-20231218163632-74b52eafd2c0 h1:Yel8NcrK4jg+biIcTxnszKh0eIpF2Vj25XEygQcTweI=
github.com/drakkan/crypto v0.0.0-20231218163632-74b52eafd2c0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= github.com/drakkan/crypto v0.0.0-20231218163632-74b52eafd2c0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
github.com/drakkan/ftp v0.0.0-20201114075148-9b9adce499a9 h1:LPH1dEblAOO/LoG7yHPMtBLXhQmjaga91/DDjWk9jWA= github.com/drakkan/ftp v0.0.0-20240210102745-f1ffc43f78d2 h1:ufiGMPFBjndWSQOst9FNP11IuMqPblI2NXbpRMUWNhk=
github.com/drakkan/ftp v0.0.0-20201114075148-9b9adce499a9/go.mod h1:2lmrmq866uF2tnje75wQHzmPXhmSWUt7Gyx2vgK1RCU= github.com/drakkan/ftp v0.0.0-20240210102745-f1ffc43f78d2/go.mod h1:4p8lUl4vQ80L598CygL+3IFtm+3nggvvW/palOlViwE=
github.com/drakkan/ftpserverlib v0.0.0-20230820193955-e7243edeb89b h1:sCtiYerLxfOQrSludkwGwwXLlSVHxpvfmyOxjCOf0ec= github.com/drakkan/ftpserverlib v0.0.0-20240212100826-a241365cb085 h1:LAKYR9z9USKeyEQK91sRWldmMOjEHLOt2NuLDx+x1UQ=
github.com/drakkan/ftpserverlib v0.0.0-20230820193955-e7243edeb89b/go.mod h1:dI9/yw/KfJ0g4wmRK8ZukUfqakLr6ZTf9VDydKoLy90= github.com/drakkan/ftpserverlib v0.0.0-20240212100826-a241365cb085/go.mod h1:9rZ27KBV3xlXmjIfd6HynND28tse8ShZJ/NQkprCKno=
github.com/drakkan/webdav v0.0.0-20230227175313-32996838bcd8 h1:tdkLkSKtYd3WSDsZXGJDKsakiNstLQJPN5HjnqCkf2c= github.com/drakkan/sftp v0.0.0-20240214104840-fbb0b8bdb30c h1:usPo/2W6Dj2rugQiEml0pwmUfY/wUgW6nLGl+q98c5k=
github.com/drakkan/webdav v0.0.0-20230227175313-32996838bcd8/go.mod h1:zOVb1QDhwwqWn2L2qZ0U3swMSO4GTSNyIwXCGO/UGWE= github.com/drakkan/sftp v0.0.0-20240214104840-fbb0b8bdb30c/go.mod h1:KMKI0t3T6hfA+lTR/ssZdunHo+uwq7ghoN09/FSu3DY=
github.com/drakkan/webdav v0.0.0-20240212101318-94e905cb9adb h1:BLT+1m0U57PevUtACnTUoMErsLJyK2ydeNiVuX8R3Lk=
github.com/drakkan/webdav v0.0.0-20240212101318-94e905cb9adb/go.mod h1:zOVb1QDhwwqWn2L2qZ0U3swMSO4GTSNyIwXCGO/UGWE=
github.com/eikenb/pipeat v0.0.0-20210730190139-06b3e6902001 h1:/ZshrfQzayqRSBDodmp3rhNCHJCff+utvgBuWRbiqu4= github.com/eikenb/pipeat v0.0.0-20210730190139-06b3e6902001 h1:/ZshrfQzayqRSBDodmp3rhNCHJCff+utvgBuWRbiqu4=
github.com/eikenb/pipeat v0.0.0-20210730190139-06b3e6902001/go.mod h1:kltMsfRMTHSFdMbK66XdS8mfMW77+FZA1fGY1xYMF84= github.com/eikenb/pipeat v0.0.0-20210730190139-06b3e6902001/go.mod h1:kltMsfRMTHSFdMbK66XdS8mfMW77+FZA1fGY1xYMF84=
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
@ -216,11 +218,15 @@ github.com/googleapis/enterprise-certificate-proxy v0.3.2 h1:Vie5ybvEvT75RniqhfF
github.com/googleapis/enterprise-certificate-proxy v0.3.2/go.mod h1:VLSiSSBs/ksPL8kq3OBOQ6WRI2QnaFynd1DCjZ62+V0= github.com/googleapis/enterprise-certificate-proxy v0.3.2/go.mod h1:VLSiSSBs/ksPL8kq3OBOQ6WRI2QnaFynd1DCjZ62+V0=
github.com/googleapis/gax-go/v2 v2.12.0 h1:A+gCJKdRfqXkr+BIRGtZLibNXf0m1f9E4HG56etFpas= github.com/googleapis/gax-go/v2 v2.12.0 h1:A+gCJKdRfqXkr+BIRGtZLibNXf0m1f9E4HG56etFpas=
github.com/googleapis/gax-go/v2 v2.12.0/go.mod h1:y+aIqrI5eb1YGMVJfuV3185Ts/D7qKpsEkdD5+I6QGU= github.com/googleapis/gax-go/v2 v2.12.0/go.mod h1:y+aIqrI5eb1YGMVJfuV3185Ts/D7qKpsEkdD5+I6QGU=
github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ=
github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48=
github.com/hashicorp/go-hclog v0.9.2/go.mod h1:5CU+agLiy3J7N7QjHK5d05KxGsuXiQLrjA0H7acj2lQ= github.com/hashicorp/go-hclog v0.9.2/go.mod h1:5CU+agLiy3J7N7QjHK5d05KxGsuXiQLrjA0H7acj2lQ=
github.com/hashicorp/go-hclog v1.6.2 h1:NOtoftovWkDheyUM/8JW3QMiXyxJK3uHRK7wV04nD2I= github.com/hashicorp/go-hclog v1.6.2 h1:NOtoftovWkDheyUM/8JW3QMiXyxJK3uHRK7wV04nD2I=
github.com/hashicorp/go-hclog v1.6.2/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M= github.com/hashicorp/go-hclog v1.6.2/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M=
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
github.com/hashicorp/go-plugin v1.6.0 h1:wgd4KxHJTVGGqWBq4QPB1i5BZNEx9BR8+OFmHDmTk8A= github.com/hashicorp/go-plugin v1.6.0 h1:wgd4KxHJTVGGqWBq4QPB1i5BZNEx9BR8+OFmHDmTk8A=
github.com/hashicorp/go-plugin v1.6.0/go.mod h1:lBS5MtSSBZk0SHc66KACcjjlU6WzEVP/8pwz68aMkCI= github.com/hashicorp/go-plugin v1.6.0/go.mod h1:lBS5MtSSBZk0SHc66KACcjjlU6WzEVP/8pwz68aMkCI=
github.com/hashicorp/go-retryablehttp v0.7.5 h1:bJj+Pj19UZMIweq/iie+1u5YCdGrnxCT9yvm0e+Nd5M= github.com/hashicorp/go-retryablehttp v0.7.5 h1:bJj+Pj19UZMIweq/iie+1u5YCdGrnxCT9yvm0e+Nd5M=
@ -311,8 +317,6 @@ github.com/pires/go-proxyproto v0.7.0/go.mod h1:Vz/1JPY/OACxWGQNIRY2BeyDmpoaWmEP
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 h1:KoWmjvw+nsYOo29YJK9vDA65RGE3NrOnUtO7a+RF9HU= github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 h1:KoWmjvw+nsYOo29YJK9vDA65RGE3NrOnUtO7a+RF9HU=
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8/go.mod h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI= github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8/go.mod h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/sftp v1.13.6 h1:JFZT4XbOU7l77xGSpOdW+pwIMqP044IyjXX6FGyEKFo=
github.com/pkg/sftp v1.13.6/go.mod h1:tz1ryNURKu77RL+GuCzmoJYxQczL3wLNNpPWagdg4Qk=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
@ -487,8 +491,9 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU=
golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=

View file

@ -297,7 +297,7 @@ func (c *BaseConnection) truncateOpenHandle(fsPath string, size int64) (int64, e
} }
// ListDir reads the directory matching virtualPath and returns a list of directory entries // ListDir reads the directory matching virtualPath and returns a list of directory entries
func (c *BaseConnection) ListDir(virtualPath string) ([]os.FileInfo, error) { func (c *BaseConnection) ListDir(virtualPath string) (*DirListerAt, error) {
if !c.User.HasPerm(dataprovider.PermListItems, virtualPath) { if !c.User.HasPerm(dataprovider.PermListItems, virtualPath) {
return nil, c.GetPermissionDeniedError() return nil, c.GetPermissionDeniedError()
} }
@ -305,12 +305,17 @@ func (c *BaseConnection) ListDir(virtualPath string) ([]os.FileInfo, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
files, err := fs.ReadDir(fsPath) lister, err := fs.ReadDir(fsPath)
if err != nil { if err != nil {
c.Log(logger.LevelDebug, "error listing directory: %+v", err) c.Log(logger.LevelDebug, "error listing directory: %+v", err)
return nil, c.GetFsError(fs, err) return nil, c.GetFsError(fs, err)
} }
return c.User.FilterListDir(files, virtualPath), nil return &DirListerAt{
virtualPath: virtualPath,
user: &c.User,
info: c.User.GetVirtualFoldersInfo(virtualPath),
lister: lister,
}, nil
} }
// CheckParentDirs tries to create the specified directory and any missing parent dirs // CheckParentDirs tries to create the specified directory and any missing parent dirs
@ -511,24 +516,42 @@ func (c *BaseConnection) RemoveDir(virtualPath string) error {
return nil return nil
} }
func (c *BaseConnection) doRecursiveRemoveDirEntry(virtualPath string, info os.FileInfo) error { func (c *BaseConnection) doRecursiveRemoveDirEntry(virtualPath string, info os.FileInfo, recursion int) error {
fs, fsPath, err := c.GetFsAndResolvedPath(virtualPath) fs, fsPath, err := c.GetFsAndResolvedPath(virtualPath)
if err != nil { if err != nil {
return err return err
} }
return c.doRecursiveRemove(fs, fsPath, virtualPath, info) return c.doRecursiveRemove(fs, fsPath, virtualPath, info, recursion)
} }
func (c *BaseConnection) doRecursiveRemove(fs vfs.Fs, fsPath, virtualPath string, info os.FileInfo) error { func (c *BaseConnection) doRecursiveRemove(fs vfs.Fs, fsPath, virtualPath string, info os.FileInfo, recursion int) error {
if info.IsDir() { if info.IsDir() {
entries, err := c.ListDir(virtualPath) if recursion >= util.MaxRecursion {
if err != nil { c.Log(logger.LevelError, "recursive rename failed, recursion too depth: %d", recursion)
return fmt.Errorf("unable to get contents for dir %q: %w", virtualPath, err) return util.ErrRecursionTooDeep
} }
for _, fi := range entries { recursion++
targetPath := path.Join(virtualPath, fi.Name()) lister, err := c.ListDir(virtualPath)
if err := c.doRecursiveRemoveDirEntry(targetPath, fi); err != nil { if err != nil {
return err return fmt.Errorf("unable to get lister for dir %q: %w", virtualPath, err)
}
defer lister.Close()
for {
entries, err := lister.Next(vfs.ListerBatchSize)
finished := errors.Is(err, io.EOF)
if err != nil && !finished {
return fmt.Errorf("unable to get content for dir %q: %w", virtualPath, err)
}
for _, fi := range entries {
targetPath := path.Join(virtualPath, fi.Name())
if err := c.doRecursiveRemoveDirEntry(targetPath, fi, recursion); err != nil {
return err
}
}
if finished {
lister.Close()
break
} }
} }
return c.RemoveDir(virtualPath) return c.RemoveDir(virtualPath)
@ -552,7 +575,7 @@ func (c *BaseConnection) RemoveAll(virtualPath string) error {
if err := c.IsRemoveDirAllowed(fs, fsPath, virtualPath); err != nil { if err := c.IsRemoveDirAllowed(fs, fsPath, virtualPath); err != nil {
return err return err
} }
return c.doRecursiveRemove(fs, fsPath, virtualPath, fi) return c.doRecursiveRemove(fs, fsPath, virtualPath, fi, 0)
} }
return c.RemoveFile(fs, fsPath, virtualPath, fi) return c.RemoveFile(fs, fsPath, virtualPath, fi)
} }
@ -626,43 +649,38 @@ func (c *BaseConnection) copyFile(virtualSourcePath, virtualTargetPath string, s
} }
func (c *BaseConnection) doRecursiveCopy(virtualSourcePath, virtualTargetPath string, srcInfo os.FileInfo, func (c *BaseConnection) doRecursiveCopy(virtualSourcePath, virtualTargetPath string, srcInfo os.FileInfo,
createTargetDir bool, createTargetDir bool, recursion int,
) error { ) error {
if srcInfo.IsDir() { if srcInfo.IsDir() {
if recursion >= util.MaxRecursion {
c.Log(logger.LevelError, "recursive copy failed, recursion too depth: %d", recursion)
return util.ErrRecursionTooDeep
}
recursion++
if createTargetDir { if createTargetDir {
if err := c.CreateDir(virtualTargetPath, false); err != nil { if err := c.CreateDir(virtualTargetPath, false); err != nil {
return fmt.Errorf("unable to create directory %q: %w", virtualTargetPath, err) return fmt.Errorf("unable to create directory %q: %w", virtualTargetPath, err)
} }
} }
entries, err := c.ListDir(virtualSourcePath) lister, err := c.ListDir(virtualSourcePath)
if err != nil { if err != nil {
return fmt.Errorf("unable to get contents for dir %q: %w", virtualSourcePath, err) return fmt.Errorf("unable to get lister for dir %q: %w", virtualSourcePath, err)
} }
for _, info := range entries { defer lister.Close()
sourcePath := path.Join(virtualSourcePath, info.Name())
targetPath := path.Join(virtualTargetPath, info.Name()) for {
targetInfo, err := c.DoStat(targetPath, 1, false) entries, err := lister.Next(vfs.ListerBatchSize)
if err == nil { finished := errors.Is(err, io.EOF)
if info.IsDir() && targetInfo.IsDir() { if err != nil && !finished {
c.Log(logger.LevelDebug, "target copy dir %q already exists", targetPath) return fmt.Errorf("unable to get contents for dir %q: %w", virtualSourcePath, err)
continue
}
} }
if err != nil && !c.IsNotExistError(err) { if err := c.recursiveCopyEntries(virtualSourcePath, virtualTargetPath, entries, recursion); err != nil {
return err return err
} }
if err := c.checkCopy(info, targetInfo, sourcePath, targetPath); err != nil { if finished {
return err return nil
}
if err := c.doRecursiveCopy(sourcePath, targetPath, info, true); err != nil {
if c.IsNotExistError(err) {
c.Log(logger.LevelInfo, "skipping copy for source path %q: %v", sourcePath, err)
continue
}
return err
} }
} }
return nil
} }
if !srcInfo.Mode().IsRegular() { if !srcInfo.Mode().IsRegular() {
c.Log(logger.LevelInfo, "skipping copy for non regular file %q", virtualSourcePath) c.Log(logger.LevelInfo, "skipping copy for non regular file %q", virtualSourcePath)
@ -672,6 +690,34 @@ func (c *BaseConnection) doRecursiveCopy(virtualSourcePath, virtualTargetPath st
return c.copyFile(virtualSourcePath, virtualTargetPath, srcInfo.Size()) return c.copyFile(virtualSourcePath, virtualTargetPath, srcInfo.Size())
} }
func (c *BaseConnection) recursiveCopyEntries(virtualSourcePath, virtualTargetPath string, entries []os.FileInfo, recursion int) error {
for _, info := range entries {
sourcePath := path.Join(virtualSourcePath, info.Name())
targetPath := path.Join(virtualTargetPath, info.Name())
targetInfo, err := c.DoStat(targetPath, 1, false)
if err == nil {
if info.IsDir() && targetInfo.IsDir() {
c.Log(logger.LevelDebug, "target copy dir %q already exists", targetPath)
continue
}
}
if err != nil && !c.IsNotExistError(err) {
return err
}
if err := c.checkCopy(info, targetInfo, sourcePath, targetPath); err != nil {
return err
}
if err := c.doRecursiveCopy(sourcePath, targetPath, info, true, recursion); err != nil {
if c.IsNotExistError(err) {
c.Log(logger.LevelInfo, "skipping copy for source path %q: %v", sourcePath, err)
continue
}
return err
}
}
return nil
}
// Copy virtualSourcePath to virtualTargetPath // Copy virtualSourcePath to virtualTargetPath
func (c *BaseConnection) Copy(virtualSourcePath, virtualTargetPath string) error { func (c *BaseConnection) Copy(virtualSourcePath, virtualTargetPath string) error {
copyFromSource := strings.HasSuffix(virtualSourcePath, "/") copyFromSource := strings.HasSuffix(virtualSourcePath, "/")
@ -717,7 +763,7 @@ func (c *BaseConnection) Copy(virtualSourcePath, virtualTargetPath string) error
defer close(done) defer close(done)
go keepConnectionAlive(c, done, 2*time.Minute) go keepConnectionAlive(c, done, 2*time.Minute)
return c.doRecursiveCopy(virtualSourcePath, destPath, srcInfo, createTargetDir) return c.doRecursiveCopy(virtualSourcePath, destPath, srcInfo, createTargetDir, 0)
} }
// Rename renames (moves) virtualSourcePath to virtualTargetPath // Rename renames (moves) virtualSourcePath to virtualTargetPath
@ -865,7 +911,8 @@ func (c *BaseConnection) doStatInternal(virtualPath string, mode int, checkFileP
convertResult bool, convertResult bool,
) (os.FileInfo, error) { ) (os.FileInfo, error) {
// for some vfs we don't create intermediary folders so we cannot simply check // for some vfs we don't create intermediary folders so we cannot simply check
// if virtualPath is a virtual folder // if virtualPath is a virtual folder. Allowing stat for hidden virtual folders
// is by purpose.
vfolders := c.User.GetVirtualFoldersInPath(path.Dir(virtualPath)) vfolders := c.User.GetVirtualFoldersInPath(path.Dir(virtualPath))
if _, ok := vfolders[virtualPath]; ok { if _, ok := vfolders[virtualPath]; ok {
return vfs.NewFileInfo(virtualPath, true, 0, time.Unix(0, 0), false), nil return vfs.NewFileInfo(virtualPath, true, 0, time.Unix(0, 0), false), nil
@ -1739,6 +1786,83 @@ func (c *BaseConnection) GetFsAndResolvedPath(virtualPath string) (vfs.Fs, strin
return fs, fsPath, nil return fs, fsPath, nil
} }
// DirListerAt defines a directory lister implementing the ListAt method.
type DirListerAt struct {
virtualPath string
user *dataprovider.User
info []os.FileInfo
mu sync.Mutex
lister vfs.DirLister
}
// Add adds the given os.FileInfo to the internal cache
func (l *DirListerAt) Add(fi os.FileInfo) {
l.mu.Lock()
defer l.mu.Unlock()
l.info = append(l.info, fi)
}
// ListAt implements sftp.ListerAt
func (l *DirListerAt) ListAt(f []os.FileInfo, _ int64) (int, error) {
l.mu.Lock()
defer l.mu.Unlock()
if len(f) == 0 {
return 0, errors.New("invalid ListAt destination, zero size")
}
if len(f) <= len(l.info) {
files := make([]os.FileInfo, 0, len(f))
for idx := len(l.info) - 1; idx >= 0; idx-- {
files = append(files, l.info[idx])
if len(files) == len(f) {
l.info = l.info[:idx]
n := copy(f, files)
return n, nil
}
}
}
limit := len(f) - len(l.info)
files, err := l.Next(limit)
n := copy(f, files)
return n, err
}
// Next reads the directory and returns a slice of up to n FileInfo values.
func (l *DirListerAt) Next(limit int) ([]os.FileInfo, error) {
for {
files, err := l.lister.Next(limit)
if err != nil && !errors.Is(err, io.EOF) {
return files, err
}
files = l.user.FilterListDir(files, l.virtualPath)
if len(l.info) > 0 {
for _, fi := range l.info {
files = util.PrependFileInfo(files, fi)
}
l.info = nil
}
if err != nil || len(files) > 0 {
return files, err
}
}
}
// Close closes the DirListerAt
func (l *DirListerAt) Close() error {
l.mu.Lock()
defer l.mu.Unlock()
return l.lister.Close()
}
func (l *DirListerAt) convertError(err error) error {
if errors.Is(err, io.EOF) {
return nil
}
return err
}
func getPermissionDeniedError(protocol string) error { func getPermissionDeniedError(protocol string) error {
switch protocol { switch protocol {
case ProtocolSFTP: case ProtocolSFTP:

View file

@ -17,10 +17,12 @@ package common
import ( import (
"errors" "errors"
"fmt" "fmt"
"io"
"os" "os"
"path" "path"
"path/filepath" "path/filepath"
"runtime" "runtime"
"strconv"
"testing" "testing"
"time" "time"
@ -601,8 +603,10 @@ func TestErrorResolvePath(t *testing.T) {
} }
conn := NewBaseConnection("", ProtocolSFTP, "", "", u) conn := NewBaseConnection("", ProtocolSFTP, "", "", u)
err := conn.doRecursiveRemoveDirEntry("/vpath", nil) err := conn.doRecursiveRemoveDirEntry("/vpath", nil, 0)
assert.Error(t, err) assert.Error(t, err)
err = conn.doRecursiveRemove(nil, "/fspath", "/vpath", vfs.NewFileInfo("vpath", true, 0, time.Now(), false), 2000)
assert.Error(t, err, util.ErrRecursionTooDeep)
err = conn.checkCopy(vfs.NewFileInfo("name", true, 0, time.Unix(0, 0), false), nil, "/source", "/target") err = conn.checkCopy(vfs.NewFileInfo("name", true, 0, time.Unix(0, 0), false), nil, "/source", "/target")
assert.Error(t, err) assert.Error(t, err)
sourceFile := filepath.Join(os.TempDir(), "f", "source") sourceFile := filepath.Join(os.TempDir(), "f", "source")
@ -700,26 +704,32 @@ func TestFilePatterns(t *testing.T) {
VirtualFolders: virtualFolders, VirtualFolders: virtualFolders,
} }
getFilteredInfo := func(dirContents []os.FileInfo, virtualPath string) []os.FileInfo {
result := user.FilterListDir(dirContents, virtualPath)
result = append(result, user.GetVirtualFoldersInfo(virtualPath)...)
return result
}
dirContents := []os.FileInfo{ dirContents := []os.FileInfo{
vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false),
vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false),
} }
// dirContents are modified in place, we need to redefine them each time // dirContents are modified in place, we need to redefine them each time
filtered := user.FilterListDir(dirContents, "/dir1") filtered := getFilteredInfo(dirContents, "/dir1")
assert.Len(t, filtered, 5) assert.Len(t, filtered, 5)
dirContents = []os.FileInfo{ dirContents = []os.FileInfo{
vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false),
vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false),
} }
filtered = user.FilterListDir(dirContents, "/dir1/vdir1") filtered = getFilteredInfo(dirContents, "/dir1/vdir1")
assert.Len(t, filtered, 2) assert.Len(t, filtered, 2)
dirContents = []os.FileInfo{ dirContents = []os.FileInfo{
vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false),
vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false),
} }
filtered = user.FilterListDir(dirContents, "/dir2/vdir2") filtered = getFilteredInfo(dirContents, "/dir2/vdir2")
require.Len(t, filtered, 1) require.Len(t, filtered, 1)
assert.Equal(t, "file1.jpg", filtered[0].Name()) assert.Equal(t, "file1.jpg", filtered[0].Name())
@ -727,7 +737,7 @@ func TestFilePatterns(t *testing.T) {
vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false),
vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false),
} }
filtered = user.FilterListDir(dirContents, "/dir2/vdir2/sub") filtered = getFilteredInfo(dirContents, "/dir2/vdir2/sub")
require.Len(t, filtered, 1) require.Len(t, filtered, 1)
assert.Equal(t, "file1.jpg", filtered[0].Name()) assert.Equal(t, "file1.jpg", filtered[0].Name())
@ -754,14 +764,14 @@ func TestFilePatterns(t *testing.T) {
vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false),
vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false),
} }
filtered = user.FilterListDir(dirContents, "/dir4") filtered = getFilteredInfo(dirContents, "/dir4")
require.Len(t, filtered, 0) require.Len(t, filtered, 0)
dirContents = []os.FileInfo{ dirContents = []os.FileInfo{
vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false),
vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false),
} }
filtered = user.FilterListDir(dirContents, "/dir4/vdir2/sub") filtered = getFilteredInfo(dirContents, "/dir4/vdir2/sub")
require.Len(t, filtered, 0) require.Len(t, filtered, 0)
dirContents = []os.FileInfo{ dirContents = []os.FileInfo{
@ -769,7 +779,7 @@ func TestFilePatterns(t *testing.T) {
vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false),
} }
filtered = user.FilterListDir(dirContents, "/dir2") filtered = getFilteredInfo(dirContents, "/dir2")
assert.Len(t, filtered, 2) assert.Len(t, filtered, 2)
dirContents = []os.FileInfo{ dirContents = []os.FileInfo{
@ -777,7 +787,7 @@ func TestFilePatterns(t *testing.T) {
vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false),
} }
filtered = user.FilterListDir(dirContents, "/dir4") filtered = getFilteredInfo(dirContents, "/dir4")
assert.Len(t, filtered, 0) assert.Len(t, filtered, 0)
dirContents = []os.FileInfo{ dirContents = []os.FileInfo{
@ -785,7 +795,7 @@ func TestFilePatterns(t *testing.T) {
vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false),
} }
filtered = user.FilterListDir(dirContents, "/dir4/sub") filtered = getFilteredInfo(dirContents, "/dir4/sub")
assert.Len(t, filtered, 0) assert.Len(t, filtered, 0)
dirContents = []os.FileInfo{ dirContents = []os.FileInfo{
@ -793,10 +803,10 @@ func TestFilePatterns(t *testing.T) {
vfs.NewFileInfo("vdir3.jpg", false, 123, time.Now(), false), vfs.NewFileInfo("vdir3.jpg", false, 123, time.Now(), false),
} }
filtered = user.FilterListDir(dirContents, "/dir1") filtered = getFilteredInfo(dirContents, "/dir1")
assert.Len(t, filtered, 5) assert.Len(t, filtered, 5)
filtered = user.FilterListDir(dirContents, "/dir2") filtered = getFilteredInfo(dirContents, "/dir2")
if assert.Len(t, filtered, 1) { if assert.Len(t, filtered, 1) {
assert.True(t, filtered[0].IsDir()) assert.True(t, filtered[0].IsDir())
} }
@ -806,14 +816,14 @@ func TestFilePatterns(t *testing.T) {
vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false),
vfs.NewFileInfo("vdir3.jpg", false, 123, time.Now(), false), vfs.NewFileInfo("vdir3.jpg", false, 123, time.Now(), false),
} }
filtered = user.FilterListDir(dirContents, "/dir1") filtered = getFilteredInfo(dirContents, "/dir1")
assert.Len(t, filtered, 2) assert.Len(t, filtered, 2)
dirContents = []os.FileInfo{ dirContents = []os.FileInfo{
vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false),
vfs.NewFileInfo("vdir3.jpg", false, 123, time.Now(), false), vfs.NewFileInfo("vdir3.jpg", false, 123, time.Now(), false),
} }
filtered = user.FilterListDir(dirContents, "/dir2") filtered = getFilteredInfo(dirContents, "/dir2")
if assert.Len(t, filtered, 1) { if assert.Len(t, filtered, 1) {
assert.False(t, filtered[0].IsDir()) assert.False(t, filtered[0].IsDir())
} }
@ -824,7 +834,7 @@ func TestFilePatterns(t *testing.T) {
vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false),
vfs.NewFileInfo("vdir3.jpg", false, 123, time.Now(), false), vfs.NewFileInfo("vdir3.jpg", false, 123, time.Now(), false),
} }
filtered = user.FilterListDir(dirContents, "/dir2") filtered = getFilteredInfo(dirContents, "/dir2")
if assert.Len(t, filtered, 2) { if assert.Len(t, filtered, 2) {
assert.False(t, filtered[0].IsDir()) assert.False(t, filtered[0].IsDir())
assert.False(t, filtered[1].IsDir()) assert.False(t, filtered[1].IsDir())
@ -832,9 +842,9 @@ func TestFilePatterns(t *testing.T) {
user.VirtualFolders = virtualFolders user.VirtualFolders = virtualFolders
user.Filters = filters user.Filters = filters
filtered = user.FilterListDir(nil, "/dir1") filtered = getFilteredInfo(nil, "/dir1")
assert.Len(t, filtered, 3) assert.Len(t, filtered, 3)
filtered = user.FilterListDir(nil, "/dir2") filtered = getFilteredInfo(nil, "/dir2")
assert.Len(t, filtered, 1) assert.Len(t, filtered, 1)
dirContents = []os.FileInfo{ dirContents = []os.FileInfo{
@ -843,7 +853,7 @@ func TestFilePatterns(t *testing.T) {
vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false),
vfs.NewFileInfo("vdir3.jpg", false, 456, time.Now(), false), vfs.NewFileInfo("vdir3.jpg", false, 456, time.Now(), false),
} }
filtered = user.FilterListDir(dirContents, "/dir2") filtered = getFilteredInfo(dirContents, "/dir2")
assert.Len(t, filtered, 2) assert.Len(t, filtered, 2)
user = dataprovider.User{ user = dataprovider.User{
@ -866,7 +876,7 @@ func TestFilePatterns(t *testing.T) {
vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false),
vfs.NewFileInfo("vdir3.jpg", false, 456, time.Now(), false), vfs.NewFileInfo("vdir3.jpg", false, 456, time.Now(), false),
} }
filtered = user.FilterListDir(dirContents, "/dir3") filtered = getFilteredInfo(dirContents, "/dir3")
assert.Len(t, filtered, 0) assert.Len(t, filtered, 0)
dirContents = nil dirContents = nil
@ -881,7 +891,7 @@ func TestFilePatterns(t *testing.T) {
dirContents = append(dirContents, vfs.NewFileInfo("ic35.*", false, 123, time.Now(), false)) dirContents = append(dirContents, vfs.NewFileInfo("ic35.*", false, 123, time.Now(), false))
dirContents = append(dirContents, vfs.NewFileInfo("file.jpg", false, 123, time.Now(), false)) dirContents = append(dirContents, vfs.NewFileInfo("file.jpg", false, 123, time.Now(), false))
filtered = user.FilterListDir(dirContents, "/dir3") filtered = getFilteredInfo(dirContents, "/dir3")
require.Len(t, filtered, 1) require.Len(t, filtered, 1)
assert.Equal(t, "ic35", filtered[0].Name()) assert.Equal(t, "ic35", filtered[0].Name())
@ -890,7 +900,7 @@ func TestFilePatterns(t *testing.T) {
vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false),
vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false),
} }
filtered = user.FilterListDir(dirContents, "/dir3/ic36") filtered = getFilteredInfo(dirContents, "/dir3/ic36")
require.Len(t, filtered, 0) require.Len(t, filtered, 0)
dirContents = []os.FileInfo{ dirContents = []os.FileInfo{
@ -898,7 +908,7 @@ func TestFilePatterns(t *testing.T) {
vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false),
vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false),
} }
filtered = user.FilterListDir(dirContents, "/dir3/ic35") filtered = getFilteredInfo(dirContents, "/dir3/ic35")
require.Len(t, filtered, 3) require.Len(t, filtered, 3)
dirContents = []os.FileInfo{ dirContents = []os.FileInfo{
@ -906,7 +916,7 @@ func TestFilePatterns(t *testing.T) {
vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false),
vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false),
} }
filtered = user.FilterListDir(dirContents, "/dir3/ic35/sub") filtered = getFilteredInfo(dirContents, "/dir3/ic35/sub")
require.Len(t, filtered, 3) require.Len(t, filtered, 3)
res, _ = user.IsFileAllowed("/dir3/file.txt") res, _ = user.IsFileAllowed("/dir3/file.txt")
@ -930,7 +940,7 @@ func TestFilePatterns(t *testing.T) {
vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false),
vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false),
} }
filtered = user.FilterListDir(dirContents, "/dir3/ic35/sub") filtered = getFilteredInfo(dirContents, "/dir3/ic35/sub")
require.Len(t, filtered, 3) require.Len(t, filtered, 3)
user.Filters.FilePatterns = append(user.Filters.FilePatterns, sdk.PatternsFilter{ user.Filters.FilePatterns = append(user.Filters.FilePatterns, sdk.PatternsFilter{
@ -949,7 +959,7 @@ func TestFilePatterns(t *testing.T) {
vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false),
vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false),
} }
filtered = user.FilterListDir(dirContents, "/dir3/ic35/sub1") filtered = getFilteredInfo(dirContents, "/dir3/ic35/sub1")
require.Len(t, filtered, 3) require.Len(t, filtered, 3)
dirContents = []os.FileInfo{ dirContents = []os.FileInfo{
@ -957,7 +967,7 @@ func TestFilePatterns(t *testing.T) {
vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false),
vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false),
} }
filtered = user.FilterListDir(dirContents, "/dir3/ic35/sub2") filtered = getFilteredInfo(dirContents, "/dir3/ic35/sub2")
require.Len(t, filtered, 2) require.Len(t, filtered, 2)
dirContents = []os.FileInfo{ dirContents = []os.FileInfo{
@ -965,7 +975,7 @@ func TestFilePatterns(t *testing.T) {
vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false),
vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false),
} }
filtered = user.FilterListDir(dirContents, "/dir3/ic35/sub2/sub1") filtered = getFilteredInfo(dirContents, "/dir3/ic35/sub2/sub1")
require.Len(t, filtered, 2) require.Len(t, filtered, 2)
res, _ = user.IsFileAllowed("/dir3/ic35/file.jpg") res, _ = user.IsFileAllowed("/dir3/ic35/file.jpg")
@ -1023,7 +1033,7 @@ func TestFilePatterns(t *testing.T) {
vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false),
vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false),
} }
filtered = user.FilterListDir(dirContents, "/dir3/ic35") filtered = getFilteredInfo(dirContents, "/dir3/ic35")
require.Len(t, filtered, 1) require.Len(t, filtered, 1)
dirContents = []os.FileInfo{ dirContents = []os.FileInfo{
@ -1031,6 +1041,116 @@ func TestFilePatterns(t *testing.T) {
vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false),
vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false),
} }
filtered = user.FilterListDir(dirContents, "/dir3/ic35/abc") filtered = getFilteredInfo(dirContents, "/dir3/ic35/abc")
require.Len(t, filtered, 1) require.Len(t, filtered, 1)
} }
func TestListerAt(t *testing.T) {
dir := t.TempDir()
user := dataprovider.User{
BaseUser: sdk.BaseUser{
Username: "u",
Password: "p",
HomeDir: dir,
Status: 1,
Permissions: map[string][]string{
"/": {"*"},
},
},
}
conn := NewBaseConnection(xid.New().String(), ProtocolSFTP, "", "", user)
lister, err := conn.ListDir("/")
require.NoError(t, err)
files, err := lister.Next(1)
require.ErrorIs(t, err, io.EOF)
require.Len(t, files, 0)
err = lister.Close()
require.NoError(t, err)
conn.User.VirtualFolders = []vfs.VirtualFolder{
{
VirtualPath: "p1",
},
{
VirtualPath: "p2",
},
{
VirtualPath: "p3",
},
}
lister, err = conn.ListDir("/")
require.NoError(t, err)
files, err = lister.Next(2)
// virtual directories exceeds the limit
require.ErrorIs(t, err, io.EOF)
require.Len(t, files, 3)
files, err = lister.Next(2)
require.ErrorIs(t, err, io.EOF)
require.Len(t, files, 0)
_, err = lister.Next(-1)
require.ErrorContains(t, err, "invalid limit")
err = lister.Close()
require.NoError(t, err)
lister, err = conn.ListDir("/")
require.NoError(t, err)
_, err = lister.ListAt(nil, 0)
require.ErrorContains(t, err, "zero size")
err = lister.Close()
require.NoError(t, err)
for i := 0; i < 100; i++ {
f, err := os.Create(filepath.Join(dir, strconv.Itoa(i)))
require.NoError(t, err)
err = f.Close()
require.NoError(t, err)
}
lister, err = conn.ListDir("/")
require.NoError(t, err)
files = make([]os.FileInfo, 18)
n, err := lister.ListAt(files, 0)
require.NoError(t, err)
require.Equal(t, 18, n)
n, err = lister.ListAt(files, 0)
require.NoError(t, err)
require.Equal(t, 18, n)
files = make([]os.FileInfo, 100)
n, err = lister.ListAt(files, 0)
require.NoError(t, err)
require.Equal(t, 64+3, n)
n, err = lister.ListAt(files, 0)
require.ErrorIs(t, err, io.EOF)
require.Equal(t, 0, n)
n, err = lister.ListAt(files, 0)
require.ErrorIs(t, err, io.EOF)
require.Equal(t, 0, n)
err = lister.Close()
require.NoError(t, err)
n, err = lister.ListAt(files, 0)
require.Error(t, err)
require.NotErrorIs(t, err, io.EOF)
require.Equal(t, 0, n)
lister, err = conn.ListDir("/")
require.NoError(t, err)
lister.Add(vfs.NewFileInfo("..", true, 0, time.Unix(0, 0), false))
lister.Add(vfs.NewFileInfo(".", true, 0, time.Unix(0, 0), false))
files = make([]os.FileInfo, 1)
n, err = lister.ListAt(files, 0)
require.NoError(t, err)
require.Equal(t, 1, n)
assert.Equal(t, ".", files[0].Name())
files = make([]os.FileInfo, 2)
n, err = lister.ListAt(files, 0)
require.NoError(t, err)
require.Equal(t, 2, n)
assert.Equal(t, "..", files[0].Name())
assert.Equal(t, "p3", files[1].Name())
files = make([]os.FileInfo, 200)
n, err = lister.ListAt(files, 0)
require.NoError(t, err)
require.Equal(t, 102, n)
assert.Equal(t, "p2", files[0].Name())
assert.Equal(t, "p1", files[1].Name())
err = lister.Close()
require.NoError(t, err)
}

View file

@ -18,7 +18,9 @@ import (
"bytes" "bytes"
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io"
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
@ -37,6 +39,7 @@ import (
"github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/smtp" "github.com/drakkan/sftpgo/v2/internal/smtp"
"github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/internal/vfs"
) )
// RetentionCheckNotification defines the supported notification methods for a retention check result // RetentionCheckNotification defines the supported notification methods for a retention check result
@ -226,8 +229,17 @@ func (c *RetentionCheck) removeFile(virtualPath string, info os.FileInfo) error
return c.conn.RemoveFile(fs, fsPath, virtualPath, info) return c.conn.RemoveFile(fs, fsPath, virtualPath, info)
} }
func (c *RetentionCheck) cleanupFolder(folderPath string) error { func (c *RetentionCheck) hasCleanupPerms(folderPath string) bool {
deleteFilesPerms := []string{dataprovider.PermDelete, dataprovider.PermDeleteFiles} if !c.conn.User.HasPerm(dataprovider.PermListItems, folderPath) {
return false
}
if !c.conn.User.HasAnyPerm([]string{dataprovider.PermDelete, dataprovider.PermDeleteFiles}, folderPath) {
return false
}
return true
}
func (c *RetentionCheck) cleanupFolder(folderPath string, recursion int) error {
startTime := time.Now() startTime := time.Now()
result := folderRetentionCheckResult{ result := folderRetentionCheckResult{
Path: folderPath, Path: folderPath,
@ -235,7 +247,15 @@ func (c *RetentionCheck) cleanupFolder(folderPath string) error {
defer func() { defer func() {
c.results = append(c.results, result) c.results = append(c.results, result)
}() }()
if !c.conn.User.HasPerm(dataprovider.PermListItems, folderPath) || !c.conn.User.HasAnyPerm(deleteFilesPerms, folderPath) { if recursion >= util.MaxRecursion {
result.Elapsed = time.Since(startTime)
result.Info = "data retention check skipped: recursion too deep"
c.conn.Log(logger.LevelError, "data retention check skipped, recursion too depth for %q: %d",
folderPath, recursion)
return util.ErrRecursionTooDeep
}
recursion++
if !c.hasCleanupPerms(folderPath) {
result.Elapsed = time.Since(startTime) result.Elapsed = time.Since(startTime)
result.Info = "data retention check skipped: no permissions" result.Info = "data retention check skipped: no permissions"
c.conn.Log(logger.LevelInfo, "user %q does not have permissions to check retention on %q, retention check skipped", c.conn.Log(logger.LevelInfo, "user %q does not have permissions to check retention on %q, retention check skipped",
@ -259,7 +279,7 @@ func (c *RetentionCheck) cleanupFolder(folderPath string) error {
} }
c.conn.Log(logger.LevelDebug, "start retention check for folder %q, retention: %v hours, delete empty dirs? %v, ignore user perms? %v", c.conn.Log(logger.LevelDebug, "start retention check for folder %q, retention: %v hours, delete empty dirs? %v, ignore user perms? %v",
folderPath, folderRetention.Retention, folderRetention.DeleteEmptyDirs, folderRetention.IgnoreUserPermissions) folderPath, folderRetention.Retention, folderRetention.DeleteEmptyDirs, folderRetention.IgnoreUserPermissions)
files, err := c.conn.ListDir(folderPath) lister, err := c.conn.ListDir(folderPath)
if err != nil { if err != nil {
result.Elapsed = time.Since(startTime) result.Elapsed = time.Since(startTime)
if err == c.conn.GetNotExistError() { if err == c.conn.GetNotExistError() {
@ -267,40 +287,54 @@ func (c *RetentionCheck) cleanupFolder(folderPath string) error {
c.conn.Log(logger.LevelDebug, "folder %q does not exist, retention check skipped", folderPath) c.conn.Log(logger.LevelDebug, "folder %q does not exist, retention check skipped", folderPath)
return nil return nil
} }
result.Error = fmt.Sprintf("unable to list directory %q", folderPath) result.Error = fmt.Sprintf("unable to get lister for directory %q", folderPath)
c.conn.Log(logger.LevelError, result.Error) c.conn.Log(logger.LevelError, result.Error)
return err return err
} }
for _, info := range files { defer lister.Close()
virtualPath := path.Join(folderPath, info.Name())
if info.IsDir() { for {
if err := c.cleanupFolder(virtualPath); err != nil { files, err := lister.Next(vfs.ListerBatchSize)
result.Elapsed = time.Since(startTime) finished := errors.Is(err, io.EOF)
result.Error = fmt.Sprintf("unable to check folder: %v", err) if err := lister.convertError(err); err != nil {
c.conn.Log(logger.LevelError, "unable to cleanup folder %q: %v", virtualPath, err) result.Elapsed = time.Since(startTime)
return err result.Error = fmt.Sprintf("unable to list directory %q", folderPath)
} c.conn.Log(logger.LevelError, "unable to list dir %q: %v", folderPath, err)
} else { return err
retentionTime := info.ModTime().Add(time.Duration(folderRetention.Retention) * time.Hour) }
if retentionTime.Before(time.Now()) { for _, info := range files {
if err := c.removeFile(virtualPath, info); err != nil { virtualPath := path.Join(folderPath, info.Name())
if info.IsDir() {
if err := c.cleanupFolder(virtualPath, recursion); err != nil {
result.Elapsed = time.Since(startTime) result.Elapsed = time.Since(startTime)
result.Error = fmt.Sprintf("unable to remove file %q: %v", virtualPath, err) result.Error = fmt.Sprintf("unable to check folder: %v", err)
c.conn.Log(logger.LevelError, "unable to remove file %q, retention %v: %v", c.conn.Log(logger.LevelError, "unable to cleanup folder %q: %v", virtualPath, err)
virtualPath, retentionTime, err)
return err return err
} }
c.conn.Log(logger.LevelDebug, "removed file %q, modification time: %v, retention: %v hours, retention time: %v", } else {
virtualPath, info.ModTime(), folderRetention.Retention, retentionTime) retentionTime := info.ModTime().Add(time.Duration(folderRetention.Retention) * time.Hour)
result.DeletedFiles++ if retentionTime.Before(time.Now()) {
result.DeletedSize += info.Size() if err := c.removeFile(virtualPath, info); err != nil {
result.Elapsed = time.Since(startTime)
result.Error = fmt.Sprintf("unable to remove file %q: %v", virtualPath, err)
c.conn.Log(logger.LevelError, "unable to remove file %q, retention %v: %v",
virtualPath, retentionTime, err)
return err
}
c.conn.Log(logger.LevelDebug, "removed file %q, modification time: %v, retention: %v hours, retention time: %v",
virtualPath, info.ModTime(), folderRetention.Retention, retentionTime)
result.DeletedFiles++
result.DeletedSize += info.Size()
}
} }
} }
if finished {
break
}
} }
if folderRetention.DeleteEmptyDirs { lister.Close()
c.checkEmptyDirRemoval(folderPath) c.checkEmptyDirRemoval(folderPath, folderRetention.DeleteEmptyDirs)
}
result.Elapsed = time.Since(startTime) result.Elapsed = time.Since(startTime)
c.conn.Log(logger.LevelDebug, "retention check completed for folder %q, deleted files: %v, deleted size: %v bytes", c.conn.Log(logger.LevelDebug, "retention check completed for folder %q, deleted files: %v, deleted size: %v bytes",
folderPath, result.DeletedFiles, result.DeletedSize) folderPath, result.DeletedFiles, result.DeletedSize)
@ -308,8 +342,8 @@ func (c *RetentionCheck) cleanupFolder(folderPath string) error {
return nil return nil
} }
func (c *RetentionCheck) checkEmptyDirRemoval(folderPath string) { func (c *RetentionCheck) checkEmptyDirRemoval(folderPath string, checkVal bool) {
if folderPath == "/" { if folderPath == "/" || !checkVal {
return return
} }
for _, folder := range c.Folders { for _, folder := range c.Folders {
@ -322,10 +356,14 @@ func (c *RetentionCheck) checkEmptyDirRemoval(folderPath string) {
dataprovider.PermDeleteDirs, dataprovider.PermDeleteDirs,
}, path.Dir(folderPath), }, path.Dir(folderPath),
) { ) {
files, err := c.conn.ListDir(folderPath) lister, err := c.conn.ListDir(folderPath)
if err == nil && len(files) == 0 { if err == nil {
err = c.conn.RemoveDir(folderPath) files, err := lister.Next(1)
c.conn.Log(logger.LevelDebug, "tried to remove empty dir %q, error: %v", folderPath, err) lister.Close()
if len(files) == 0 && errors.Is(err, io.EOF) {
err = c.conn.RemoveDir(folderPath)
c.conn.Log(logger.LevelDebug, "tried to remove empty dir %q, error: %v", folderPath, err)
}
} }
} }
} }
@ -339,7 +377,7 @@ func (c *RetentionCheck) Start() error {
startTime := time.Now() startTime := time.Now()
for _, folder := range c.Folders { for _, folder := range c.Folders {
if folder.Retention > 0 { if folder.Retention > 0 {
if err := c.cleanupFolder(folder.Path); err != nil { if err := c.cleanupFolder(folder.Path, 0); err != nil {
c.conn.Log(logger.LevelError, "retention check failed, unable to cleanup folder %q", folder.Path) c.conn.Log(logger.LevelError, "retention check failed, unable to cleanup folder %q", folder.Path)
c.sendNotifications(time.Since(startTime), err) c.sendNotifications(time.Since(startTime), err)
return err return err

View file

@ -28,6 +28,7 @@ import (
"github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/smtp" "github.com/drakkan/sftpgo/v2/internal/smtp"
"github.com/drakkan/sftpgo/v2/internal/util"
) )
func TestRetentionValidation(t *testing.T) { func TestRetentionValidation(t *testing.T) {
@ -272,7 +273,9 @@ func TestRetentionPermissionsAndGetFolder(t *testing.T) {
conn.SetProtocol(ProtocolDataRetention) conn.SetProtocol(ProtocolDataRetention)
conn.ID = fmt.Sprintf("data_retention_%v", user.Username) conn.ID = fmt.Sprintf("data_retention_%v", user.Username)
check.conn = conn check.conn = conn
assert.False(t, check.hasCleanupPerms(check.Folders[2].Path))
check.updateUserPermissions() check.updateUserPermissions()
assert.True(t, check.hasCleanupPerms(check.Folders[2].Path))
assert.Equal(t, []string{dataprovider.PermListItems, dataprovider.PermDelete}, conn.User.Permissions["/"]) assert.Equal(t, []string{dataprovider.PermListItems, dataprovider.PermDelete}, conn.User.Permissions["/"])
assert.Equal(t, []string{dataprovider.PermListItems}, conn.User.Permissions["/dir1"]) assert.Equal(t, []string{dataprovider.PermListItems}, conn.User.Permissions["/dir1"])
assert.Equal(t, []string{dataprovider.PermAny}, conn.User.Permissions["/dir2"]) assert.Equal(t, []string{dataprovider.PermAny}, conn.User.Permissions["/dir2"])
@ -391,8 +394,11 @@ func TestCleanupErrors(t *testing.T) {
err := check.removeFile("missing file", nil) err := check.removeFile("missing file", nil)
assert.Error(t, err) assert.Error(t, err)
err = check.cleanupFolder("/") err = check.cleanupFolder("/", 0)
assert.Error(t, err) assert.Error(t, err)
err = check.cleanupFolder("/", 1000)
assert.ErrorIs(t, err, util.ErrRecursionTooDeep)
assert.True(t, RetentionChecks.remove(user.Username)) assert.True(t, RetentionChecks.remove(user.Username))
} }

View file

@ -988,11 +988,16 @@ func getFileWriter(conn *BaseConnection, virtualPath string, expectedSize int64)
return w, numFiles, truncatedSize, cancelFn, nil return w, numFiles, truncatedSize, cancelFn, nil
} }
func addZipEntry(wr *zipWriterWrapper, conn *BaseConnection, entryPath, baseDir string) error { func addZipEntry(wr *zipWriterWrapper, conn *BaseConnection, entryPath, baseDir string, recursion int) error {
if entryPath == wr.Name { if entryPath == wr.Name {
// skip the archive itself // skip the archive itself
return nil return nil
} }
if recursion >= util.MaxRecursion {
eventManagerLog(logger.LevelError, "unable to add zip entry %q, recursion too deep: %v", entryPath, recursion)
return util.ErrRecursionTooDeep
}
recursion++
info, err := conn.DoStat(entryPath, 1, false) info, err := conn.DoStat(entryPath, 1, false)
if err != nil { if err != nil {
eventManagerLog(logger.LevelError, "unable to add zip entry %q, stat error: %v", entryPath, err) eventManagerLog(logger.LevelError, "unable to add zip entry %q, stat error: %v", entryPath, err)
@ -1018,25 +1023,42 @@ func addZipEntry(wr *zipWriterWrapper, conn *BaseConnection, entryPath, baseDir
eventManagerLog(logger.LevelError, "unable to create zip entry %q: %v", entryPath, err) eventManagerLog(logger.LevelError, "unable to create zip entry %q: %v", entryPath, err)
return fmt.Errorf("unable to create zip entry %q: %w", entryPath, err) return fmt.Errorf("unable to create zip entry %q: %w", entryPath, err)
} }
contents, err := conn.ListDir(entryPath) lister, err := conn.ListDir(entryPath)
if err != nil { if err != nil {
eventManagerLog(logger.LevelError, "unable to add zip entry %q, read dir error: %v", entryPath, err) eventManagerLog(logger.LevelError, "unable to add zip entry %q, get dir lister error: %v", entryPath, err)
return fmt.Errorf("unable to add zip entry %q: %w", entryPath, err) return fmt.Errorf("unable to add zip entry %q: %w", entryPath, err)
} }
for _, info := range contents { defer lister.Close()
fullPath := util.CleanPath(path.Join(entryPath, info.Name()))
if err := addZipEntry(wr, conn, fullPath, baseDir); err != nil { for {
eventManagerLog(logger.LevelError, "unable to add zip entry: %v", err) contents, err := lister.Next(vfs.ListerBatchSize)
return err finished := errors.Is(err, io.EOF)
if err := lister.convertError(err); err != nil {
eventManagerLog(logger.LevelError, "unable to add zip entry %q, read dir error: %v", entryPath, err)
return fmt.Errorf("unable to add zip entry %q: %w", entryPath, err)
}
for _, info := range contents {
fullPath := util.CleanPath(path.Join(entryPath, info.Name()))
if err := addZipEntry(wr, conn, fullPath, baseDir, recursion); err != nil {
eventManagerLog(logger.LevelError, "unable to add zip entry: %v", err)
return err
}
}
if finished {
return nil
} }
} }
return nil
} }
if !info.Mode().IsRegular() { if !info.Mode().IsRegular() {
// we only allow regular files // we only allow regular files
eventManagerLog(logger.LevelInfo, "skipping zip entry for non regular file %q", entryPath) eventManagerLog(logger.LevelInfo, "skipping zip entry for non regular file %q", entryPath)
return nil return nil
} }
return addFileToZip(wr, conn, entryPath, entryName, info.ModTime())
}
func addFileToZip(wr *zipWriterWrapper, conn *BaseConnection, entryPath, entryName string, modTime time.Time) error {
reader, cancelFn, err := getFileReader(conn, entryPath) reader, cancelFn, err := getFileReader(conn, entryPath)
if err != nil { if err != nil {
eventManagerLog(logger.LevelError, "unable to add zip entry %q, cannot open file: %v", entryPath, err) eventManagerLog(logger.LevelError, "unable to add zip entry %q, cannot open file: %v", entryPath, err)
@ -1048,7 +1070,7 @@ func addZipEntry(wr *zipWriterWrapper, conn *BaseConnection, entryPath, baseDir
f, err := wr.Writer.CreateHeader(&zip.FileHeader{ f, err := wr.Writer.CreateHeader(&zip.FileHeader{
Name: entryName, Name: entryName,
Method: zip.Deflate, Method: zip.Deflate,
Modified: info.ModTime(), Modified: modTime,
}) })
if err != nil { if err != nil {
eventManagerLog(logger.LevelError, "unable to create zip entry %q: %v", entryPath, err) eventManagerLog(logger.LevelError, "unable to create zip entry %q: %v", entryPath, err)
@ -1890,18 +1912,28 @@ func getArchiveBaseDir(paths []string) string {
func getSizeForPath(conn *BaseConnection, p string, info os.FileInfo) (int64, error) { func getSizeForPath(conn *BaseConnection, p string, info os.FileInfo) (int64, error) {
if info.IsDir() { if info.IsDir() {
var dirSize int64 var dirSize int64
entries, err := conn.ListDir(p) lister, err := conn.ListDir(p)
if err != nil { if err != nil {
return 0, err return 0, err
} }
for _, entry := range entries { defer lister.Close()
size, err := getSizeForPath(conn, path.Join(p, entry.Name()), entry) for {
if err != nil { entries, err := lister.Next(vfs.ListerBatchSize)
finished := errors.Is(err, io.EOF)
if err != nil && !finished {
return 0, err return 0, err
} }
dirSize += size for _, entry := range entries {
size, err := getSizeForPath(conn, path.Join(p, entry.Name()), entry)
if err != nil {
return 0, err
}
dirSize += size
}
if finished {
return dirSize, nil
}
} }
return dirSize, nil
} }
if info.Mode().IsRegular() { if info.Mode().IsRegular() {
return info.Size(), nil return info.Size(), nil
@ -1978,7 +2010,7 @@ func executeCompressFsActionForUser(c dataprovider.EventActionFsCompress, replac
} }
startTime := time.Now() startTime := time.Now()
for _, item := range paths { for _, item := range paths {
if err := addZipEntry(zipWriter, conn, item, baseDir); err != nil { if err := addZipEntry(zipWriter, conn, item, baseDir, 0); err != nil {
closeWriterAndUpdateQuota(writer, conn, name, "", numFiles, truncatedSize, err, operationUpload, startTime) //nolint:errcheck closeWriterAndUpdateQuota(writer, conn, name, "", numFiles, truncatedSize, err, operationUpload, startTime) //nolint:errcheck
return err return err
} }

View file

@ -1835,7 +1835,7 @@ func TestFilesystemActionErrors(t *testing.T) {
Writer: zip.NewWriter(bytes.NewBuffer(nil)), Writer: zip.NewWriter(bytes.NewBuffer(nil)),
Entries: map[string]bool{}, Entries: map[string]bool{},
} }
err = addZipEntry(wr, conn, "/adir/sub/f.dat", "/adir/sub/sub") err = addZipEntry(wr, conn, "/adir/sub/f.dat", "/adir/sub/sub", 0)
assert.Error(t, err) assert.Error(t, err)
assert.Contains(t, getErrorString(err), "is outside base dir") assert.Contains(t, getErrorString(err), "is outside base dir")
} }

View file

@ -131,9 +131,9 @@ func getDefenderHostQuery() string {
sqlTableDefenderHosts, sqlPlaceholders[0], sqlPlaceholders[1]) sqlTableDefenderHosts, sqlPlaceholders[0], sqlPlaceholders[1])
} }
func getDefenderEventsQuery(hostIDS []int64) string { func getDefenderEventsQuery(hostIDs []int64) string {
var sb strings.Builder var sb strings.Builder
for _, hID := range hostIDS { for _, hID := range hostIDs {
if sb.Len() == 0 { if sb.Len() == 0 {
sb.WriteString("(") sb.WriteString("(")
} else { } else {

View file

@ -676,8 +676,7 @@ func (u *User) GetVirtualFoldersInPath(virtualPath string) map[string]bool {
result := make(map[string]bool) result := make(map[string]bool)
for idx := range u.VirtualFolders { for idx := range u.VirtualFolders {
v := &u.VirtualFolders[idx] dirsForPath := util.GetDirsForVirtualPath(u.VirtualFolders[idx].VirtualPath)
dirsForPath := util.GetDirsForVirtualPath(v.VirtualPath)
for index := range dirsForPath { for index := range dirsForPath {
d := dirsForPath[index] d := dirsForPath[index]
if d == "/" { if d == "/" {
@ -716,13 +715,34 @@ func (u *User) hasVirtualDirs() bool {
return numFolders > 0 return numFolders > 0
} }
// FilterListDir adds virtual folders and remove hidden items from the given files list // GetVirtualFoldersInfo returns []os.FileInfo for virtual folders
func (u *User) GetVirtualFoldersInfo(virtualPath string) []os.FileInfo {
filter := u.getPatternsFilterForPath(virtualPath)
if !u.hasVirtualDirs() && filter.DenyPolicy != sdk.DenyPolicyHide {
return nil
}
vdirs := u.GetVirtualFoldersInPath(virtualPath)
result := make([]os.FileInfo, 0, len(vdirs))
for dir := range u.GetVirtualFoldersInPath(virtualPath) {
dirName := path.Base(dir)
if filter.DenyPolicy == sdk.DenyPolicyHide {
if !filter.CheckAllowed(dirName) {
continue
}
}
result = append(result, vfs.NewFileInfo(dirName, true, 0, time.Unix(0, 0), false))
}
return result
}
// FilterListDir removes hidden items from the given files list
func (u *User) FilterListDir(dirContents []os.FileInfo, virtualPath string) []os.FileInfo { func (u *User) FilterListDir(dirContents []os.FileInfo, virtualPath string) []os.FileInfo {
filter := u.getPatternsFilterForPath(virtualPath) filter := u.getPatternsFilterForPath(virtualPath)
if !u.hasVirtualDirs() && filter.DenyPolicy != sdk.DenyPolicyHide { if !u.hasVirtualDirs() && filter.DenyPolicy != sdk.DenyPolicyHide {
return dirContents return dirContents
} }
vdirs := make(map[string]bool) vdirs := make(map[string]bool)
for dir := range u.GetVirtualFoldersInPath(virtualPath) { for dir := range u.GetVirtualFoldersInPath(virtualPath) {
dirName := path.Base(dir) dirName := path.Base(dir)
@ -735,36 +755,24 @@ func (u *User) FilterListDir(dirContents []os.FileInfo, virtualPath string) []os
} }
validIdx := 0 validIdx := 0
for index, fi := range dirContents { for idx := range dirContents {
for dir := range vdirs { fi := dirContents[idx]
if fi.Name() == dir {
if !fi.IsDir() { if fi.Name() != "." && fi.Name() != ".." {
fi = vfs.NewFileInfo(dir, true, 0, time.Unix(0, 0), false) if _, ok := vdirs[fi.Name()]; ok {
dirContents[index] = fi continue
}
if filter.DenyPolicy == sdk.DenyPolicyHide {
if !filter.CheckAllowed(fi.Name()) {
continue
} }
delete(vdirs, dir)
}
}
if filter.DenyPolicy == sdk.DenyPolicyHide {
if filter.CheckAllowed(fi.Name()) {
dirContents[validIdx] = fi
validIdx++
} }
} }
dirContents[validIdx] = fi
validIdx++
} }
if filter.DenyPolicy == sdk.DenyPolicyHide { return dirContents[:validIdx]
for idx := validIdx; idx < len(dirContents); idx++ {
dirContents[idx] = nil
}
dirContents = dirContents[:validIdx]
}
for dir := range vdirs {
fi := vfs.NewFileInfo(dir, true, 0, time.Unix(0, 0), false)
dirContents = append(dirContents, fi)
}
return dirContents
} }
// IsMappedPath returns true if the specified filesystem path has a virtual folder mapping. // IsMappedPath returns true if the specified filesystem path has a virtual folder mapping.

View file

@ -2544,14 +2544,14 @@ func TestRename(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
err = client.MakeDir(path.Join(otherDir, testDir)) err = client.MakeDir(path.Join(otherDir, testDir))
assert.NoError(t, err) assert.NoError(t, err)
code, response, err := client.SendCustomCommand(fmt.Sprintf("SITE CHMOD 0001 %v", otherDir)) code, response, err := client.SendCommand(fmt.Sprintf("SITE CHMOD 0001 %v", otherDir))
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, ftp.StatusCommandOK, code) assert.Equal(t, ftp.StatusCommandOK, code)
assert.Equal(t, "SITE CHMOD command successful", response) assert.Equal(t, "SITE CHMOD command successful", response)
err = client.Rename(testDir, path.Join(otherDir, testDir)) err = client.Rename(testDir, path.Join(otherDir, testDir))
assert.Error(t, err) assert.Error(t, err)
code, response, err = client.SendCustomCommand(fmt.Sprintf("SITE CHMOD 755 %v", otherDir)) code, response, err = client.SendCommand(fmt.Sprintf("SITE CHMOD 755 %v", otherDir))
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, ftp.StatusCommandOK, code) assert.Equal(t, ftp.StatusCommandOK, code)
assert.Equal(t, "SITE CHMOD command successful", response) assert.Equal(t, "SITE CHMOD command successful", response)
@ -2611,7 +2611,7 @@ func TestSymlink(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0)
assert.NoError(t, err) assert.NoError(t, err)
code, _, err := client.SendCustomCommand(fmt.Sprintf("SITE SYMLINK %v %v", testFileName, testFileName+".link")) code, _, err := client.SendCommand(fmt.Sprintf("SITE SYMLINK %v %v", testFileName, testFileName+".link"))
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, ftp.StatusCommandOK, code) assert.Equal(t, ftp.StatusCommandOK, code)
@ -2622,15 +2622,15 @@ func TestSymlink(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
err = client.MakeDir(path.Join(otherDir, testDir)) err = client.MakeDir(path.Join(otherDir, testDir))
assert.NoError(t, err) assert.NoError(t, err)
code, response, err := client.SendCustomCommand(fmt.Sprintf("SITE CHMOD 0001 %v", otherDir)) code, response, err := client.SendCommand(fmt.Sprintf("SITE CHMOD 0001 %v", otherDir))
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, ftp.StatusCommandOK, code) assert.Equal(t, ftp.StatusCommandOK, code)
assert.Equal(t, "SITE CHMOD command successful", response) assert.Equal(t, "SITE CHMOD command successful", response)
code, _, err = client.SendCustomCommand(fmt.Sprintf("SITE SYMLINK %v %v", testDir, path.Join(otherDir, testDir))) code, _, err = client.SendCommand(fmt.Sprintf("SITE SYMLINK %v %v", testDir, path.Join(otherDir, testDir)))
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, ftp.StatusFileUnavailable, code) assert.Equal(t, ftp.StatusFileUnavailable, code)
code, response, err = client.SendCustomCommand(fmt.Sprintf("SITE CHMOD 755 %v", otherDir)) code, response, err = client.SendCommand(fmt.Sprintf("SITE CHMOD 755 %v", otherDir))
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, ftp.StatusCommandOK, code) assert.Equal(t, ftp.StatusCommandOK, code)
assert.Equal(t, "SITE CHMOD command successful", response) assert.Equal(t, "SITE CHMOD command successful", response)
@ -2860,17 +2860,17 @@ func TestAllocateAvailable(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
client, err := getFTPClient(user, false, nil) client, err := getFTPClient(user, false, nil)
if assert.NoError(t, err) { if assert.NoError(t, err) {
code, response, err := client.SendCustomCommand("allo 2000000") code, response, err := client.SendCommand("allo 2000000")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, ftp.StatusCommandOK, code) assert.Equal(t, ftp.StatusCommandOK, code)
assert.Equal(t, "Done !", response) assert.Equal(t, "Done !", response)
code, response, err = client.SendCustomCommand("AVBL /vdir") code, response, err = client.SendCommand("AVBL /vdir")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, ftp.StatusFile, code) assert.Equal(t, ftp.StatusFile, code)
assert.Equal(t, "110", response) assert.Equal(t, "110", response)
code, _, err = client.SendCustomCommand("AVBL") code, _, err = client.SendCommand("AVBL")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, ftp.StatusFile, code) assert.Equal(t, ftp.StatusFile, code)
@ -2886,7 +2886,7 @@ func TestAllocateAvailable(t *testing.T) {
testFileSize := user.QuotaSize - 1 testFileSize := user.QuotaSize - 1
err = createTestFile(testFilePath, testFileSize) err = createTestFile(testFilePath, testFileSize)
assert.NoError(t, err) assert.NoError(t, err)
code, response, err := client.SendCustomCommand("allo 1000") code, response, err := client.SendCommand("allo 1000")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, ftp.StatusCommandOK, code) assert.Equal(t, ftp.StatusCommandOK, code)
assert.Equal(t, "Done !", response) assert.Equal(t, "Done !", response)
@ -2894,7 +2894,7 @@ func TestAllocateAvailable(t *testing.T) {
err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0)
assert.NoError(t, err) assert.NoError(t, err)
code, response, err = client.SendCustomCommand("AVBL") code, response, err = client.SendCommand("AVBL")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, ftp.StatusFile, code) assert.Equal(t, ftp.StatusFile, code)
assert.Equal(t, "1", response) assert.Equal(t, "1", response)
@ -2909,7 +2909,7 @@ func TestAllocateAvailable(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
client, err = getFTPClient(user, false, nil) client, err = getFTPClient(user, false, nil)
if assert.NoError(t, err) { if assert.NoError(t, err) {
code, response, err := client.SendCustomCommand("AVBL") code, response, err := client.SendCommand("AVBL")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, ftp.StatusFile, code) assert.Equal(t, ftp.StatusFile, code)
assert.Equal(t, "1", response) assert.Equal(t, "1", response)
@ -2925,7 +2925,7 @@ func TestAllocateAvailable(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
client, err = getFTPClient(user, false, nil) client, err = getFTPClient(user, false, nil)
if assert.NoError(t, err) { if assert.NoError(t, err) {
code, response, err := client.SendCustomCommand("AVBL") code, response, err := client.SendCommand("AVBL")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, ftp.StatusFile, code) assert.Equal(t, ftp.StatusFile, code)
assert.Equal(t, "5242880", response) assert.Equal(t, "5242880", response)
@ -2941,7 +2941,7 @@ func TestAllocateAvailable(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
client, err = getFTPClient(user, false, nil) client, err = getFTPClient(user, false, nil)
if assert.NoError(t, err) { if assert.NoError(t, err) {
code, response, err := client.SendCustomCommand("AVBL") code, response, err := client.SendCommand("AVBL")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, ftp.StatusFile, code) assert.Equal(t, ftp.StatusFile, code)
assert.Equal(t, "5242880", response) assert.Equal(t, "5242880", response)
@ -2958,12 +2958,12 @@ func TestAllocateAvailable(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
client, err = getFTPClient(user, false, nil) client, err = getFTPClient(user, false, nil)
if assert.NoError(t, err) { if assert.NoError(t, err) {
code, response, err := client.SendCustomCommand("allo 10000") code, response, err := client.SendCommand("allo 10000")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, ftp.StatusCommandOK, code) assert.Equal(t, ftp.StatusCommandOK, code)
assert.Equal(t, "Done !", response) assert.Equal(t, "Done !", response)
code, response, err = client.SendCustomCommand("AVBL") code, response, err = client.SendCommand("AVBL")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, ftp.StatusFile, code) assert.Equal(t, ftp.StatusFile, code)
assert.Equal(t, "100", response) assert.Equal(t, "100", response)
@ -2977,7 +2977,7 @@ func TestAllocateAvailable(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
client, err = getFTPClient(user, false, nil) client, err = getFTPClient(user, false, nil)
if assert.NoError(t, err) { if assert.NoError(t, err) {
code, response, err := client.SendCustomCommand("AVBL") code, response, err := client.SendCommand("AVBL")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, ftp.StatusFile, code) assert.Equal(t, ftp.StatusFile, code)
assert.Equal(t, "0", response) assert.Equal(t, "0", response)
@ -2989,7 +2989,7 @@ func TestAllocateAvailable(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
client, err = getFTPClient(user, false, nil) client, err = getFTPClient(user, false, nil)
if assert.NoError(t, err) { if assert.NoError(t, err) {
code, response, err := client.SendCustomCommand("AVBL") code, response, err := client.SendCommand("AVBL")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, ftp.StatusFile, code) assert.Equal(t, ftp.StatusFile, code)
assert.Equal(t, "1", response) assert.Equal(t, "1", response)
@ -3013,7 +3013,7 @@ func TestAvailableSFTPFs(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
client, err := getFTPClient(sftpUser, false, nil) client, err := getFTPClient(sftpUser, false, nil)
if assert.NoError(t, err) { if assert.NoError(t, err) {
code, response, err := client.SendCustomCommand("AVBL /") code, response, err := client.SendCommand("AVBL /")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, ftp.StatusFile, code) assert.Equal(t, ftp.StatusFile, code)
avblSize, err := strconv.ParseInt(response, 10, 64) avblSize, err := strconv.ParseInt(response, 10, 64)
@ -3051,7 +3051,7 @@ func TestChtimes(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
mtime := time.Now().Format("20060102150405") mtime := time.Now().Format("20060102150405")
code, response, err := client.SendCustomCommand(fmt.Sprintf("MFMT %v %v", mtime, testFileName)) code, response, err := client.SendCommand(fmt.Sprintf("MFMT %v %v", mtime, testFileName))
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, ftp.StatusFile, code) assert.Equal(t, ftp.StatusFile, code)
assert.Equal(t, fmt.Sprintf("Modify=%v; %v", mtime, testFileName), response) assert.Equal(t, fmt.Sprintf("Modify=%v; %v", mtime, testFileName), response)
@ -3097,7 +3097,7 @@ func TestChown(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0)
assert.NoError(t, err) assert.NoError(t, err)
code, response, err := client.SendCustomCommand(fmt.Sprintf("SITE CHOWN 1000:1000 %v", testFileName)) code, response, err := client.SendCommand(fmt.Sprintf("SITE CHOWN 1000:1000 %v", testFileName))
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, ftp.StatusFileUnavailable, code) assert.Equal(t, ftp.StatusFileUnavailable, code)
assert.Equal(t, "Couldn't chown: operation unsupported", response) assert.Equal(t, "Couldn't chown: operation unsupported", response)
@ -3135,7 +3135,7 @@ func TestChmod(t *testing.T) {
err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0)
assert.NoError(t, err) assert.NoError(t, err)
code, response, err := client.SendCustomCommand(fmt.Sprintf("SITE CHMOD 600 %v", testFileName)) code, response, err := client.SendCommand(fmt.Sprintf("SITE CHMOD 600 %v", testFileName))
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, ftp.StatusCommandOK, code) assert.Equal(t, ftp.StatusCommandOK, code)
assert.Equal(t, "SITE CHMOD command successful", response) assert.Equal(t, "SITE CHMOD command successful", response)
@ -3182,7 +3182,7 @@ func TestCombineDisabled(t *testing.T) {
err = checkBasicFTP(client) err = checkBasicFTP(client)
assert.NoError(t, err) assert.NoError(t, err)
code, response, err := client.SendCustomCommand("COMB file file.1 file.2") code, response, err := client.SendCommand("COMB file file.1 file.2")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, ftp.StatusNotImplemented, code) assert.Equal(t, ftp.StatusNotImplemented, code)
assert.Equal(t, "COMB support is disabled", response) assert.Equal(t, "COMB support is disabled", response)
@ -3208,12 +3208,12 @@ func TestActiveModeDisabled(t *testing.T) {
if assert.NoError(t, err) { if assert.NoError(t, err) {
err = checkBasicFTP(client) err = checkBasicFTP(client)
assert.NoError(t, err) assert.NoError(t, err)
code, response, err := client.SendCustomCommand("PORT 10,2,0,2,4,31") code, response, err := client.SendCommand("PORT 10,2,0,2,4,31")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, ftp.StatusNotAvailable, code) assert.Equal(t, ftp.StatusNotAvailable, code)
assert.Equal(t, "PORT command is disabled", response) assert.Equal(t, "PORT command is disabled", response)
code, response, err = client.SendCustomCommand("EPRT |1|132.235.1.2|6275|") code, response, err = client.SendCommand("EPRT |1|132.235.1.2|6275|")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, ftp.StatusNotAvailable, code) assert.Equal(t, ftp.StatusNotAvailable, code)
assert.Equal(t, "EPRT command is disabled", response) assert.Equal(t, "EPRT command is disabled", response)
@ -3224,12 +3224,12 @@ func TestActiveModeDisabled(t *testing.T) {
client, err = getFTPClient(user, false, nil) client, err = getFTPClient(user, false, nil)
if assert.NoError(t, err) { if assert.NoError(t, err) {
code, response, err := client.SendCustomCommand("PORT 10,2,0,2,4,31") code, response, err := client.SendCommand("PORT 10,2,0,2,4,31")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, ftp.StatusBadArguments, code) assert.Equal(t, ftp.StatusBadArguments, code)
assert.Equal(t, "Your request does not meet the configured security requirements", response) assert.Equal(t, "Your request does not meet the configured security requirements", response)
code, response, err = client.SendCustomCommand("EPRT |1|132.235.1.2|6275|") code, response, err = client.SendCommand("EPRT |1|132.235.1.2|6275|")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, ftp.StatusBadArguments, code) assert.Equal(t, ftp.StatusBadArguments, code)
assert.Equal(t, "Your request does not meet the configured security requirements", response) assert.Equal(t, "Your request does not meet the configured security requirements", response)
@ -3253,7 +3253,7 @@ func TestSITEDisabled(t *testing.T) {
err = checkBasicFTP(client) err = checkBasicFTP(client)
assert.NoError(t, err) assert.NoError(t, err)
code, response, err := client.SendCustomCommand("SITE CHMOD 600 afile.txt") code, response, err := client.SendCommand("SITE CHMOD 600 afile.txt")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, ftp.StatusBadCommand, code) assert.Equal(t, ftp.StatusBadCommand, code)
assert.Equal(t, "SITE support is disabled", response) assert.Equal(t, "SITE support is disabled", response)
@ -3298,12 +3298,12 @@ func TestHASH(t *testing.T) {
err = f.Close() err = f.Close()
assert.NoError(t, err) assert.NoError(t, err)
code, response, err := client.SendCustomCommand(fmt.Sprintf("XSHA256 %v", testFileName)) code, response, err := client.SendCommand(fmt.Sprintf("XSHA256 %v", testFileName))
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, ftp.StatusRequestedFileActionOK, code) assert.Equal(t, ftp.StatusRequestedFileActionOK, code)
assert.Contains(t, response, hash) assert.Contains(t, response, hash)
code, response, err = client.SendCustomCommand(fmt.Sprintf("HASH %v", testFileName)) code, response, err = client.SendCommand(fmt.Sprintf("HASH %v", testFileName))
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, ftp.StatusFile, code) assert.Equal(t, ftp.StatusFile, code)
assert.Contains(t, response, hash) assert.Contains(t, response, hash)
@ -3359,7 +3359,7 @@ func TestCombine(t *testing.T) {
err = ftpUploadFile(testFilePath, testFileName+".2", testFileSize, client, 0) err = ftpUploadFile(testFilePath, testFileName+".2", testFileSize, client, 0)
assert.NoError(t, err) assert.NoError(t, err)
code, response, err := client.SendCustomCommand(fmt.Sprintf("COMB %v %v %v", testFileName, testFileName+".1", testFileName+".2")) code, response, err := client.SendCommand(fmt.Sprintf("COMB %v %v %v", testFileName, testFileName+".1", testFileName+".2"))
assert.NoError(t, err) assert.NoError(t, err)
if user.Username == defaultUsername { if user.Username == defaultUsername {
assert.Equal(t, ftp.StatusRequestedFileActionOK, code) assert.Equal(t, ftp.StatusRequestedFileActionOK, code)

View file

@ -291,7 +291,7 @@ func (c *Connection) Symlink(oldname, newname string) error {
} }
// ReadDir implements ClientDriverExtensionFilelist // ReadDir implements ClientDriverExtensionFilelist
func (c *Connection) ReadDir(name string) ([]os.FileInfo, error) { func (c *Connection) ReadDir(name string) (ftpserver.DirLister, error) {
c.UpdateLastActivity() c.UpdateLastActivity()
if c.doWildcardListDir { if c.doWildcardListDir {
@ -302,7 +302,17 @@ func (c *Connection) ReadDir(name string) ([]os.FileInfo, error) {
// - dir*/*.xml is not supported // - dir*/*.xml is not supported
name = path.Dir(name) name = path.Dir(name)
c.clientContext.SetListPath(name) c.clientContext.SetListPath(name)
return c.getListDirWithWildcards(name, baseName) lister, err := c.ListDir(name)
if err != nil {
return nil, err
}
return &patternDirLister{
DirLister: lister,
pattern: baseName,
lastCommand: c.clientContext.GetLastCommand(),
dirName: name,
connectionPath: c.clientContext.Path(),
}, nil
} }
return c.ListDir(name) return c.ListDir(name)
@ -506,31 +516,6 @@ func (c *Connection) handleFTPUploadToExistingFile(fs vfs.Fs, flags int, resolve
return t, nil return t, nil
} }
func (c *Connection) getListDirWithWildcards(dirName, pattern string) ([]os.FileInfo, error) {
files, err := c.ListDir(dirName)
if err != nil {
return files, err
}
validIdx := 0
var relativeBase string
if c.clientContext.GetLastCommand() != "NLST" {
relativeBase = getPathRelativeTo(c.clientContext.Path(), dirName)
}
for _, fi := range files {
match, err := path.Match(pattern, fi.Name())
if err != nil {
return files, err
}
if match {
files[validIdx] = vfs.NewFileInfo(path.Join(relativeBase, fi.Name()), fi.IsDir(), fi.Size(),
fi.ModTime(), true)
validIdx++
}
}
return files[:validIdx], nil
}
func (c *Connection) isListDirWithWildcards(name string) bool { func (c *Connection) isListDirWithWildcards(name string) bool {
if strings.ContainsAny(name, "*?[]^") { if strings.ContainsAny(name, "*?[]^") {
lastCommand := c.clientContext.GetLastCommand() lastCommand := c.clientContext.GetLastCommand()
@ -559,3 +544,40 @@ func getPathRelativeTo(base, target string) string {
base = path.Dir(path.Clean(base)) base = path.Dir(path.Clean(base))
} }
} }
type patternDirLister struct {
vfs.DirLister
pattern string
lastCommand string
dirName string
connectionPath string
}
func (l *patternDirLister) Next(limit int) ([]os.FileInfo, error) {
for {
files, err := l.DirLister.Next(limit)
if len(files) == 0 {
return files, err
}
validIdx := 0
var relativeBase string
if l.lastCommand != "NLST" {
relativeBase = getPathRelativeTo(l.connectionPath, l.dirName)
}
for _, fi := range files {
match, errMatch := path.Match(l.pattern, fi.Name())
if errMatch != nil {
return nil, errMatch
}
if match {
files[validIdx] = vfs.NewFileInfo(path.Join(relativeBase, fi.Name()), fi.IsDir(), fi.Size(),
fi.ModTime(), true)
validIdx++
}
}
files = files[:validIdx]
if err != nil || len(files) > 0 {
return files, err
}
}
}

View file

@ -74,12 +74,12 @@ func readUserFolder(w http.ResponseWriter, r *http.Request) {
defer common.Connections.Remove(connection.GetID()) defer common.Connections.Remove(connection.GetID())
name := connection.User.GetCleanedPath(r.URL.Query().Get("path")) name := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
contents, err := connection.ReadDir(name) lister, err := connection.ReadDir(name)
if err != nil { if err != nil {
sendAPIResponse(w, r, err, "Unable to get directory contents", getMappedStatusCode(err)) sendAPIResponse(w, r, err, "Unable to get directory lister", getMappedStatusCode(err))
return return
} }
renderAPIDirContents(w, r, contents, false) renderAPIDirContents(w, lister, false)
} }
func createUserDir(w http.ResponseWriter, r *http.Request) { func createUserDir(w http.ResponseWriter, r *http.Request) {

View file

@ -213,12 +213,12 @@ func (s *httpdServer) readBrowsableShareContents(w http.ResponseWriter, r *http.
} }
defer common.Connections.Remove(connection.GetID()) defer common.Connections.Remove(connection.GetID())
contents, err := connection.ReadDir(name) lister, err := connection.ReadDir(name)
if err != nil { if err != nil {
sendAPIResponse(w, r, err, "Unable to get directory contents", getMappedStatusCode(err)) sendAPIResponse(w, r, err, "Unable to get directory lister", getMappedStatusCode(err))
return return
} }
renderAPIDirContents(w, r, contents, true) renderAPIDirContents(w, lister, true)
} }
func (s *httpdServer) downloadBrowsableSharedFile(w http.ResponseWriter, r *http.Request) { func (s *httpdServer) downloadBrowsableSharedFile(w http.ResponseWriter, r *http.Request) {

View file

@ -17,6 +17,7 @@ package httpd
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -44,6 +45,7 @@ import (
"github.com/drakkan/sftpgo/v2/internal/plugin" "github.com/drakkan/sftpgo/v2/internal/plugin"
"github.com/drakkan/sftpgo/v2/internal/smtp" "github.com/drakkan/sftpgo/v2/internal/smtp"
"github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/internal/vfs"
) )
type pwdChange struct { type pwdChange struct {
@ -280,23 +282,40 @@ func getSearchFilters(w http.ResponseWriter, r *http.Request) (int, int, string,
return limit, offset, order, err return limit, offset, order, err
} }
func renderAPIDirContents(w http.ResponseWriter, r *http.Request, contents []os.FileInfo, omitNonRegularFiles bool) { func renderAPIDirContents(w http.ResponseWriter, lister vfs.DirLister, omitNonRegularFiles bool) {
results := make([]map[string]any, 0, len(contents)) defer lister.Close()
for _, info := range contents {
if omitNonRegularFiles && !info.Mode().IsDir() && !info.Mode().IsRegular() { dataGetter := func(limit, _ int) ([]byte, int, error) {
continue contents, err := lister.Next(limit)
if errors.Is(err, io.EOF) {
err = nil
} }
res := make(map[string]any) if err != nil {
res["name"] = info.Name() return nil, 0, err
if info.Mode().IsRegular() {
res["size"] = info.Size()
} }
res["mode"] = info.Mode() results := make([]map[string]any, 0, len(contents))
res["last_modified"] = info.ModTime().UTC().Format(time.RFC3339) for _, info := range contents {
results = append(results, res) if omitNonRegularFiles && !info.Mode().IsDir() && !info.Mode().IsRegular() {
continue
}
res := make(map[string]any)
res["name"] = info.Name()
if info.Mode().IsRegular() {
res["size"] = info.Size()
}
res["mode"] = info.Mode()
res["last_modified"] = info.ModTime().UTC().Format(time.RFC3339)
results = append(results, res)
}
data, err := json.Marshal(results)
count := limit
if len(results) == 0 {
count = 0
}
return data, count, err
} }
render.JSON(w, r, results) streamJSONArray(w, defaultQueryLimit, dataGetter)
} }
func streamData(w io.Writer, data []byte) { func streamData(w io.Writer, data []byte) {
@ -355,7 +374,7 @@ func renderCompressedFiles(w http.ResponseWriter, conn *Connection, baseDir stri
for _, file := range files { for _, file := range files {
fullPath := util.CleanPath(path.Join(baseDir, file)) fullPath := util.CleanPath(path.Join(baseDir, file))
if err := addZipEntry(wr, conn, fullPath, baseDir); err != nil { if err := addZipEntry(wr, conn, fullPath, baseDir, 0); err != nil {
if share != nil { if share != nil {
dataprovider.UpdateShareLastUse(share, -1) //nolint:errcheck dataprovider.UpdateShareLastUse(share, -1) //nolint:errcheck
} }
@ -371,7 +390,12 @@ func renderCompressedFiles(w http.ResponseWriter, conn *Connection, baseDir stri
} }
} }
func addZipEntry(wr *zip.Writer, conn *Connection, entryPath, baseDir string) error { func addZipEntry(wr *zip.Writer, conn *Connection, entryPath, baseDir string, recursion int) error {
if recursion >= util.MaxRecursion {
conn.Log(logger.LevelDebug, "unable to add zip entry %q, recursion too depth: %d", entryPath, recursion)
return util.ErrRecursionTooDeep
}
recursion++
info, err := conn.Stat(entryPath, 1) info, err := conn.Stat(entryPath, 1)
if err != nil { if err != nil {
conn.Log(logger.LevelDebug, "unable to add zip entry %q, stat error: %v", entryPath, err) conn.Log(logger.LevelDebug, "unable to add zip entry %q, stat error: %v", entryPath, err)
@ -392,24 +416,39 @@ func addZipEntry(wr *zip.Writer, conn *Connection, entryPath, baseDir string) er
conn.Log(logger.LevelError, "unable to create zip entry %q: %v", entryPath, err) conn.Log(logger.LevelError, "unable to create zip entry %q: %v", entryPath, err)
return err return err
} }
contents, err := conn.ReadDir(entryPath) lister, err := conn.ReadDir(entryPath)
if err != nil { if err != nil {
conn.Log(logger.LevelDebug, "unable to add zip entry %q, read dir error: %v", entryPath, err) conn.Log(logger.LevelDebug, "unable to add zip entry %q, get list dir error: %v", entryPath, err)
return err return err
} }
for _, info := range contents { defer lister.Close()
fullPath := util.CleanPath(path.Join(entryPath, info.Name()))
if err := addZipEntry(wr, conn, fullPath, baseDir); err != nil { for {
contents, err := lister.Next(vfs.ListerBatchSize)
finished := errors.Is(err, io.EOF)
if err != nil && !finished {
return err return err
} }
for _, info := range contents {
fullPath := util.CleanPath(path.Join(entryPath, info.Name()))
if err := addZipEntry(wr, conn, fullPath, baseDir, recursion); err != nil {
return err
}
}
if finished {
return nil
}
} }
return nil
} }
if !info.Mode().IsRegular() { if !info.Mode().IsRegular() {
// we only allow regular files // we only allow regular files
conn.Log(logger.LevelInfo, "skipping zip entry for non regular file %q", entryPath) conn.Log(logger.LevelInfo, "skipping zip entry for non regular file %q", entryPath)
return nil return nil
} }
return addFileToZipEntry(wr, conn, entryPath, entryName, info)
}
func addFileToZipEntry(wr *zip.Writer, conn *Connection, entryPath, entryName string, info os.FileInfo) error {
reader, err := conn.getFileReader(entryPath, 0, http.MethodGet) reader, err := conn.getFileReader(entryPath, 0, http.MethodGet)
if err != nil { if err != nil {
conn.Log(logger.LevelDebug, "unable to add zip entry %q, cannot open file: %v", entryPath, err) conn.Log(logger.LevelDebug, "unable to add zip entry %q, cannot open file: %v", entryPath, err)

View file

@ -88,7 +88,7 @@ func (c *Connection) Stat(name string, mode int) (os.FileInfo, error) {
} }
// ReadDir returns a list of directory entries // ReadDir returns a list of directory entries
func (c *Connection) ReadDir(name string) ([]os.FileInfo, error) { func (c *Connection) ReadDir(name string) (vfs.DirLister, error) {
c.UpdateLastActivity() c.UpdateLastActivity()
return c.ListDir(name) return c.ListDir(name)

View file

@ -16028,7 +16028,7 @@ func TestWebGetFiles(t *testing.T) {
setBearerForReq(req, webAPIToken) setBearerForReq(req, webAPIToken)
rr = executeRequest(req) rr = executeRequest(req)
checkResponseCode(t, http.StatusNotFound, rr) checkResponseCode(t, http.StatusNotFound, rr)
assert.Contains(t, rr.Body.String(), "Unable to get directory contents") assert.Contains(t, rr.Body.String(), "Unable to get directory lister")
req, _ = http.NewRequest(http.MethodGet, webClientFilesPath+"?path="+testFileName, nil) req, _ = http.NewRequest(http.MethodGet, webClientFilesPath+"?path="+testFileName, nil)
setJWTCookieForReq(req, webToken) setJWTCookieForReq(req, webToken)

View file

@ -2196,7 +2196,7 @@ func TestRecoverer(t *testing.T) {
} }
func TestStreamJSONArray(t *testing.T) { func TestStreamJSONArray(t *testing.T) {
dataGetter := func(limit, offset int) ([]byte, int, error) { dataGetter := func(_, _ int) ([]byte, int, error) {
return nil, 0, nil return nil, 0, nil
} }
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
@ -2268,12 +2268,14 @@ func TestZipErrors(t *testing.T) {
assert.Contains(t, err.Error(), "write error") assert.Contains(t, err.Error(), "write error")
} }
err = addZipEntry(wr, connection, "/"+filepath.Base(testDir), "/") err = addZipEntry(wr, connection, "/"+filepath.Base(testDir), "/", 0)
if assert.Error(t, err) { if assert.Error(t, err) {
assert.Contains(t, err.Error(), "write error") assert.Contains(t, err.Error(), "write error")
} }
err = addZipEntry(wr, connection, "/"+filepath.Base(testDir), "/", 2000)
assert.ErrorIs(t, err, util.ErrRecursionTooDeep)
err = addZipEntry(wr, connection, "/"+filepath.Base(testDir), path.Join("/", filepath.Base(testDir), "dir")) err = addZipEntry(wr, connection, "/"+filepath.Base(testDir), path.Join("/", filepath.Base(testDir), "dir"), 0)
if assert.Error(t, err) { if assert.Error(t, err) {
assert.Contains(t, err.Error(), "is outside base dir") assert.Contains(t, err.Error(), "is outside base dir")
} }
@ -2282,14 +2284,14 @@ func TestZipErrors(t *testing.T) {
err = os.WriteFile(testFilePath, util.GenerateRandomBytes(65535), os.ModePerm) err = os.WriteFile(testFilePath, util.GenerateRandomBytes(65535), os.ModePerm)
assert.NoError(t, err) assert.NoError(t, err)
err = addZipEntry(wr, connection, path.Join("/", filepath.Base(testDir), filepath.Base(testFilePath)), err = addZipEntry(wr, connection, path.Join("/", filepath.Base(testDir), filepath.Base(testFilePath)),
"/"+filepath.Base(testDir)) "/"+filepath.Base(testDir), 0)
if assert.Error(t, err) { if assert.Error(t, err) {
assert.Contains(t, err.Error(), "write error") assert.Contains(t, err.Error(), "write error")
} }
connection.User.Permissions["/"] = []string{dataprovider.PermListItems} connection.User.Permissions["/"] = []string{dataprovider.PermListItems}
err = addZipEntry(wr, connection, path.Join("/", filepath.Base(testDir), filepath.Base(testFilePath)), err = addZipEntry(wr, connection, path.Join("/", filepath.Base(testDir), filepath.Base(testFilePath)),
"/"+filepath.Base(testDir)) "/"+filepath.Base(testDir), 0)
assert.ErrorIs(t, err, os.ErrPermission) assert.ErrorIs(t, err, os.ErrPermission)
// creating a virtual folder to a missing path stat is ok but readdir fails // creating a virtual folder to a missing path stat is ok but readdir fails
@ -2301,14 +2303,14 @@ func TestZipErrors(t *testing.T) {
}) })
connection.User = user connection.User = user
wr = zip.NewWriter(bytes.NewBuffer(make([]byte, 0))) wr = zip.NewWriter(bytes.NewBuffer(make([]byte, 0)))
err = addZipEntry(wr, connection, user.VirtualFolders[0].VirtualPath, "/") err = addZipEntry(wr, connection, user.VirtualFolders[0].VirtualPath, "/", 0)
assert.Error(t, err) assert.Error(t, err)
user.Filters.FilePatterns = append(user.Filters.FilePatterns, sdk.PatternsFilter{ user.Filters.FilePatterns = append(user.Filters.FilePatterns, sdk.PatternsFilter{
Path: "/", Path: "/",
DeniedPatterns: []string{"*.zip"}, DeniedPatterns: []string{"*.zip"},
}) })
err = addZipEntry(wr, connection, "/"+filepath.Base(testDir), "/") err = addZipEntry(wr, connection, "/"+filepath.Base(testDir), "/", 0)
assert.ErrorIs(t, err, os.ErrPermission) assert.ErrorIs(t, err, os.ErrPermission)
err = os.RemoveAll(testDir) err = os.RemoveAll(testDir)

View file

@ -958,33 +958,50 @@ func (s *httpdServer) handleShareGetDirContents(w http.ResponseWriter, r *http.R
} }
defer common.Connections.Remove(connection.GetID()) defer common.Connections.Remove(connection.GetID())
contents, err := connection.ReadDir(name) lister, err := connection.ReadDir(name)
if err != nil { if err != nil {
sendAPIResponse(w, r, err, getI18NErrorString(err, util.I18nErrorDirListGeneric), getMappedStatusCode(err)) sendAPIResponse(w, r, err, getI18NErrorString(err, util.I18nErrorDirListGeneric), getMappedStatusCode(err))
return return
} }
results := make([]map[string]any, 0, len(contents)) defer lister.Close()
for _, info := range contents {
if !info.Mode().IsDir() && !info.Mode().IsRegular() { dataGetter := func(limit, _ int) ([]byte, int, error) {
continue contents, err := lister.Next(limit)
if errors.Is(err, io.EOF) {
err = nil
} }
res := make(map[string]any) if err != nil {
if info.IsDir() { return nil, 0, err
res["type"] = "1"
res["size"] = ""
} else {
res["type"] = "2"
res["size"] = info.Size()
} }
res["meta"] = fmt.Sprintf("%v_%v", res["type"], info.Name()) results := make([]map[string]any, 0, len(contents))
res["name"] = info.Name() for _, info := range contents {
res["url"] = getFileObjectURL(share.GetRelativePath(name), info.Name(), if !info.Mode().IsDir() && !info.Mode().IsRegular() {
path.Join(webClientPubSharesPath, share.ShareID, "browse")) continue
res["last_modified"] = getFileObjectModTime(info.ModTime()) }
results = append(results, res) res := make(map[string]any)
if info.IsDir() {
res["type"] = "1"
res["size"] = ""
} else {
res["type"] = "2"
res["size"] = info.Size()
}
res["meta"] = fmt.Sprintf("%v_%v", res["type"], info.Name())
res["name"] = info.Name()
res["url"] = getFileObjectURL(share.GetRelativePath(name), info.Name(),
path.Join(webClientPubSharesPath, share.ShareID, "browse"))
res["last_modified"] = getFileObjectModTime(info.ModTime())
results = append(results, res)
}
data, err := json.Marshal(results)
count := limit
if len(results) == 0 {
count = 0
}
return data, count, err
} }
render.JSON(w, r, results) streamJSONArray(w, defaultQueryLimit, dataGetter)
} }
func (s *httpdServer) handleClientUploadToShare(w http.ResponseWriter, r *http.Request) { func (s *httpdServer) handleClientUploadToShare(w http.ResponseWriter, r *http.Request) {
@ -1146,43 +1163,59 @@ func (s *httpdServer) handleClientGetDirContents(w http.ResponseWriter, r *http.
defer common.Connections.Remove(connection.GetID()) defer common.Connections.Remove(connection.GetID())
name := connection.User.GetCleanedPath(r.URL.Query().Get("path")) name := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
contents, err := connection.ReadDir(name) lister, err := connection.ReadDir(name)
if err != nil { if err != nil {
statusCode := getMappedStatusCode(err) statusCode := getMappedStatusCode(err)
sendAPIResponse(w, r, err, i18nListDirMsg(statusCode), statusCode) sendAPIResponse(w, r, err, i18nListDirMsg(statusCode), statusCode)
return return
} }
defer lister.Close()
dirTree := r.URL.Query().Get("dirtree") == "1" dirTree := r.URL.Query().Get("dirtree") == "1"
results := make([]map[string]any, 0, len(contents)) dataGetter := func(limit, _ int) ([]byte, int, error) {
for _, info := range contents { contents, err := lister.Next(limit)
res := make(map[string]any) if errors.Is(err, io.EOF) {
res["url"] = getFileObjectURL(name, info.Name(), webClientFilesPath) err = nil
if info.IsDir() { }
res["type"] = "1" if err != nil {
res["size"] = "" return nil, 0, err
res["dir_path"] = url.QueryEscape(path.Join(name, info.Name())) }
} else { results := make([]map[string]any, 0, len(contents))
if dirTree { for _, info := range contents {
continue res := make(map[string]any)
} res["url"] = getFileObjectURL(name, info.Name(), webClientFilesPath)
res["type"] = "2" if info.IsDir() {
if info.Mode()&os.ModeSymlink != 0 { res["type"] = "1"
res["size"] = "" res["size"] = ""
res["dir_path"] = url.QueryEscape(path.Join(name, info.Name()))
} else { } else {
res["size"] = info.Size() if dirTree {
if info.Size() < httpdMaxEditFileSize { continue
res["edit_url"] = strings.Replace(res["url"].(string), webClientFilesPath, webClientEditFilePath, 1) }
res["type"] = "2"
if info.Mode()&os.ModeSymlink != 0 {
res["size"] = ""
} else {
res["size"] = info.Size()
if info.Size() < httpdMaxEditFileSize {
res["edit_url"] = strings.Replace(res["url"].(string), webClientFilesPath, webClientEditFilePath, 1)
}
} }
} }
res["meta"] = fmt.Sprintf("%v_%v", res["type"], info.Name())
res["name"] = info.Name()
res["last_modified"] = getFileObjectModTime(info.ModTime())
results = append(results, res)
} }
res["meta"] = fmt.Sprintf("%v_%v", res["type"], info.Name()) data, err := json.Marshal(results)
res["name"] = info.Name() count := limit
res["last_modified"] = getFileObjectModTime(info.ModTime()) if len(results) == 0 {
results = append(results, res) count = 0
}
return data, count, err
} }
render.JSON(w, r, results) streamJSONArray(w, defaultQueryLimit, dataGetter)
} }
func (s *httpdServer) handleClientGetFiles(w http.ResponseWriter, r *http.Request) { func (s *httpdServer) handleClientGetFiles(w http.ResponseWriter, r *http.Request) {
@ -1917,27 +1950,45 @@ func doCheckExist(w http.ResponseWriter, r *http.Request, connection *Connection
return return
} }
contents, err := connection.ListDir(name) lister, err := connection.ListDir(name)
if err != nil { if err != nil {
sendAPIResponse(w, r, err, "Unable to get directory contents", getMappedStatusCode(err)) sendAPIResponse(w, r, err, "Unable to get directory contents", getMappedStatusCode(err))
return return
} }
existing := make([]map[string]any, 0) defer lister.Close()
for _, info := range contents {
if util.Contains(filesList.Files, info.Name()) { dataGetter := func(limit, _ int) ([]byte, int, error) {
res := make(map[string]any) contents, err := lister.Next(limit)
res["name"] = info.Name() if errors.Is(err, io.EOF) {
if info.IsDir() { err = nil
res["type"] = "1"
res["size"] = ""
} else {
res["type"] = "2"
res["size"] = info.Size()
}
existing = append(existing, res)
} }
if err != nil {
return nil, 0, err
}
existing := make([]map[string]any, 0)
for _, info := range contents {
if util.Contains(filesList.Files, info.Name()) {
res := make(map[string]any)
res["name"] = info.Name()
if info.IsDir() {
res["type"] = "1"
res["size"] = ""
} else {
res["type"] = "2"
res["size"] = info.Size()
}
existing = append(existing, res)
}
}
data, err := json.Marshal(existing)
count := limit
if len(existing) == 0 {
count = 0
}
return data, count, err
} }
render.JSON(w, r, existing)
streamJSONArray(w, defaultQueryLimit, dataGetter)
} }
func checkShareRedirectURL(next, base string) (bool, string) { func checkShareRedirectURL(next, base string) (bool, string) {

View file

@ -91,8 +91,8 @@ func (s *Service) initLogger() {
// Start initializes and starts the service // Start initializes and starts the service
func (s *Service) Start(disableAWSInstallationCode bool) error { func (s *Service) Start(disableAWSInstallationCode bool) error {
s.initLogger() s.initLogger()
logger.Info(logSender, "", "starting SFTPGo %v, config dir: %v, config file: %v, log max size: %v log max backups: %v "+ logger.Info(logSender, "", "starting SFTPGo %s, config dir: %s, config file: %s, log max size: %d log max backups: %d "+
"log max age: %v log level: %v, log compress: %v, log utc time: %v, load data from: %q, grace time: %d secs", "log max age: %d log level: %s, log compress: %t, log utc time: %t, load data from: %q, grace time: %d secs",
version.GetAsString(), s.ConfigDir, s.ConfigFile, s.LogMaxSize, s.LogMaxBackups, s.LogMaxAge, s.LogLevel, version.GetAsString(), s.ConfigDir, s.ConfigFile, s.LogMaxSize, s.LogMaxBackups, s.LogMaxAge, s.LogLevel,
s.LogCompress, s.LogUTCTime, s.LoadDataFrom, graceTime) s.LogCompress, s.LogUTCTime, s.LoadDataFrom, graceTime)
// in portable mode we don't read configuration from file // in portable mode we don't read configuration from file

View file

@ -216,16 +216,16 @@ func (c *Connection) Filelist(request *sftp.Request) (sftp.ListerAt, error) {
switch request.Method { switch request.Method {
case "List": case "List":
files, err := c.ListDir(request.Filepath) lister, err := c.ListDir(request.Filepath)
if err != nil { if err != nil {
return nil, err return nil, err
} }
modTime := time.Unix(0, 0) modTime := time.Unix(0, 0)
if request.Filepath != "/" || c.folderPrefix != "" { if request.Filepath != "/" || c.folderPrefix != "" {
files = util.PrependFileInfo(files, vfs.NewFileInfo("..", true, 0, modTime, false)) lister.Add(vfs.NewFileInfo("..", true, 0, modTime, false))
} }
files = util.PrependFileInfo(files, vfs.NewFileInfo(".", true, 0, modTime, false)) lister.Add(vfs.NewFileInfo(".", true, 0, modTime, false))
return listerAt(files), nil return lister, nil
case "Stat": case "Stat":
if !c.User.HasPerm(dataprovider.PermListItems, path.Dir(request.Filepath)) { if !c.User.HasPerm(dataprovider.PermListItems, path.Dir(request.Filepath)) {
return nil, sftp.ErrSSHFxPermissionDenied return nil, sftp.ErrSSHFxPermissionDenied

View file

@ -1294,6 +1294,17 @@ func TestSCPProtocolMessages(t *testing.T) {
if assert.Error(t, err) { if assert.Error(t, err) {
assert.Equal(t, protocolErrorMsg, err.Error()) assert.Equal(t, protocolErrorMsg, err.Error())
} }
mockSSHChannel = MockChannel{
Buffer: bytes.NewBuffer(respBuffer),
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
ReadError: nil,
WriteError: writeErr,
}
scpCommand.connection.channel = &mockSSHChannel
err = scpCommand.downloadDirs(nil, nil)
assert.ErrorIs(t, err, writeErr)
} }
func TestSCPTestDownloadProtocolMessages(t *testing.T) { func TestSCPTestDownloadProtocolMessages(t *testing.T) {

View file

@ -384,47 +384,65 @@ func (c *scpCommand) handleRecursiveDownload(fs vfs.Fs, dirPath, virtualPath str
if err != nil { if err != nil {
return err return err
} }
files, err := fs.ReadDir(dirPath) // dirPath is a fs path, not a virtual path
lister, err := fs.ReadDir(dirPath)
if err != nil { if err != nil {
c.sendErrorMessage(fs, err) c.sendErrorMessage(fs, err)
return err return err
} }
files = c.connection.User.FilterListDir(files, fs.GetRelativePath(dirPath)) defer lister.Close()
vdirs := c.connection.User.GetVirtualFoldersInfo(virtualPath)
var dirs []string var dirs []string
for _, file := range files { for {
filePath := fs.GetRelativePath(fs.Join(dirPath, file.Name())) files, err := lister.Next(vfs.ListerBatchSize)
if file.Mode().IsRegular() || file.Mode()&os.ModeSymlink != 0 { finished := errors.Is(err, io.EOF)
err = c.handleDownload(filePath) if err != nil && !finished {
if err != nil { c.sendErrorMessage(fs, err)
break return err
}
} else if file.IsDir() {
dirs = append(dirs, filePath)
} }
} files = c.connection.User.FilterListDir(files, fs.GetRelativePath(dirPath))
if err != nil { if len(vdirs) > 0 {
c.sendErrorMessage(fs, err) files = append(files, vdirs...)
return err vdirs = nil
} }
for _, dir := range dirs { for _, file := range files {
err = c.handleDownload(dir) filePath := fs.GetRelativePath(fs.Join(dirPath, file.Name()))
if err != nil { if file.Mode().IsRegular() || file.Mode()&os.ModeSymlink != 0 {
err = c.handleDownload(filePath)
if err != nil {
c.sendErrorMessage(fs, err)
return err
}
} else if file.IsDir() {
dirs = append(dirs, filePath)
}
}
if finished {
break break
} }
} }
if err != nil { lister.Close()
return c.downloadDirs(fs, dirs)
}
err = errors.New("unable to send directory for non recursive copy")
c.sendErrorMessage(nil, err)
return err
}
func (c *scpCommand) downloadDirs(fs vfs.Fs, dirs []string) error {
for _, dir := range dirs {
if err := c.handleDownload(dir); err != nil {
c.sendErrorMessage(fs, err) c.sendErrorMessage(fs, err)
return err return err
} }
err = c.sendProtocolMessage("E\n")
if err != nil {
return err
}
return c.readConfirmationMessage()
} }
err = fmt.Errorf("unable to send directory for non recursive copy") if err := c.sendProtocolMessage("E\n"); err != nil {
c.sendErrorMessage(nil, err) return err
return err }
return c.readConfirmationMessage()
} }
func (c *scpCommand) sendDownloadFileData(fs vfs.Fs, filePath string, stat os.FileInfo, transfer *transfer) error { func (c *scpCommand) sendDownloadFileData(fs vfs.Fs, filePath string, stat os.FileInfo, transfer *transfer) error {

View file

@ -376,6 +376,7 @@ func (c *Configuration) Initialize(configDir string) error {
c.loadModuli(configDir) c.loadModuli(configDir)
sftp.SetSFTPExtensions(sftpExtensions...) //nolint:errcheck // we configure valid SFTP Extensions so we cannot get an error sftp.SetSFTPExtensions(sftpExtensions...) //nolint:errcheck // we configure valid SFTP Extensions so we cannot get an error
sftp.MaxFilelist = vfs.ListerBatchSize
if err := c.configureSecurityOptions(serverConfig); err != nil { if err := c.configureSecurityOptions(serverConfig); err != nil {
return err return err

View file

@ -15,6 +15,7 @@
package util package util
import ( import (
"errors"
"fmt" "fmt"
) )
@ -24,12 +25,16 @@ const (
"sftpgo serve -c \"<path to dir containing the default config file and templates directory>\"" "sftpgo serve -c \"<path to dir containing the default config file and templates directory>\""
) )
// MaxRecursion defines the maximum number of allowed recursions
const MaxRecursion = 1000
// errors definitions // errors definitions
var ( var (
ErrValidation = NewValidationError("") ErrValidation = NewValidationError("")
ErrNotFound = NewRecordNotFoundError("") ErrNotFound = NewRecordNotFoundError("")
ErrMethodDisabled = NewMethodDisabledError("") ErrMethodDisabled = NewMethodDisabledError("")
ErrGeneric = NewGenericError("") ErrGeneric = NewGenericError("")
ErrRecursionTooDeep = errors.New("recursion too deep")
) )
// ValidationError raised if input data is not valid // ValidationError raised if input data is not valid

View file

@ -56,6 +56,10 @@ const (
azFolderKey = "hdi_isfolder" azFolderKey = "hdi_isfolder"
) )
var (
azureBlobDefaultPageSize = int32(5000)
)
// AzureBlobFs is a Fs implementation for Azure Blob storage. // AzureBlobFs is a Fs implementation for Azure Blob storage.
type AzureBlobFs struct { type AzureBlobFs struct {
connectionID string connectionID string
@ -308,7 +312,7 @@ func (fs *AzureBlobFs) Rename(source, target string) (int, int64, error) {
if err != nil { if err != nil {
return -1, -1, err return -1, -1, err
} }
return fs.renameInternal(source, target, fi) return fs.renameInternal(source, target, fi, 0)
} }
// Remove removes the named file or (empty) directory. // Remove removes the named file or (empty) directory.
@ -408,76 +412,23 @@ func (*AzureBlobFs) Truncate(_ string, _ int64) error {
// ReadDir reads the directory named by dirname and returns // ReadDir reads the directory named by dirname and returns
// a list of directory entries. // a list of directory entries.
func (fs *AzureBlobFs) ReadDir(dirname string) ([]os.FileInfo, error) { func (fs *AzureBlobFs) ReadDir(dirname string) (DirLister, error) {
var result []os.FileInfo
// dirname must be already cleaned // dirname must be already cleaned
prefix := fs.getPrefix(dirname) prefix := fs.getPrefix(dirname)
modTimes, err := getFolderModTimes(fs.getStorageID(), dirname)
if err != nil {
return result, err
}
prefixes := make(map[string]bool)
pager := fs.containerClient.NewListBlobsHierarchyPager("/", &container.ListBlobsHierarchyOptions{ pager := fs.containerClient.NewListBlobsHierarchyPager("/", &container.ListBlobsHierarchyOptions{
Include: container.ListBlobsInclude{ Include: container.ListBlobsInclude{
//Metadata: true, //Metadata: true,
}, },
Prefix: &prefix, Prefix: &prefix,
MaxResults: &azureBlobDefaultPageSize,
}) })
for pager.More() { return &azureBlobDirLister{
ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) paginator: pager,
defer cancelFn() timeout: fs.ctxTimeout,
prefix: prefix,
resp, err := pager.NextPage(ctx) prefixes: make(map[string]bool),
if err != nil { }, nil
metric.AZListObjectsCompleted(err)
return result, err
}
for _, blobPrefix := range resp.ListBlobsHierarchySegmentResponse.Segment.BlobPrefixes {
name := util.GetStringFromPointer(blobPrefix.Name)
// we don't support prefixes == "/" this will be sent if a key starts with "/"
if name == "" || name == "/" {
continue
}
// sometime we have duplicate prefixes, maybe an Azurite bug
name = strings.TrimPrefix(name, prefix)
if _, ok := prefixes[strings.TrimSuffix(name, "/")]; ok {
continue
}
result = append(result, NewFileInfo(name, true, 0, time.Unix(0, 0), false))
prefixes[strings.TrimSuffix(name, "/")] = true
}
for _, blobItem := range resp.ListBlobsHierarchySegmentResponse.Segment.BlobItems {
name := util.GetStringFromPointer(blobItem.Name)
name = strings.TrimPrefix(name, prefix)
size := int64(0)
isDir := false
modTime := time.Unix(0, 0)
if blobItem.Properties != nil {
size = util.GetIntFromPointer(blobItem.Properties.ContentLength)
modTime = util.GetTimeFromPointer(blobItem.Properties.LastModified)
contentType := util.GetStringFromPointer(blobItem.Properties.ContentType)
isDir = checkDirectoryMarkers(contentType, blobItem.Metadata)
if isDir {
// check if the dir is already included, it will be sent as blob prefix if it contains at least one item
if _, ok := prefixes[name]; ok {
continue
}
prefixes[name] = true
}
}
if t, ok := modTimes[name]; ok {
modTime = util.GetTimeFromMsecSinceEpoch(t)
}
result = append(result, NewFileInfo(name, isDir, size, modTime, false))
}
}
metric.AZListObjectsCompleted(nil)
return result, nil
} }
// IsUploadResumeSupported returns true if resuming uploads is supported. // IsUploadResumeSupported returns true if resuming uploads is supported.
@ -569,7 +520,8 @@ func (fs *AzureBlobFs) getFileNamesInPrefix(fsPrefix string) (map[string]bool, e
Include: container.ListBlobsInclude{ Include: container.ListBlobsInclude{
//Metadata: true, //Metadata: true,
}, },
Prefix: &prefix, Prefix: &prefix,
MaxResults: &azureBlobDefaultPageSize,
}) })
for pager.More() { for pager.More() {
@ -615,7 +567,8 @@ func (fs *AzureBlobFs) GetDirSize(dirname string) (int, int64, error) {
Include: container.ListBlobsInclude{ Include: container.ListBlobsInclude{
Metadata: true, Metadata: true,
}, },
Prefix: &prefix, Prefix: &prefix,
MaxResults: &azureBlobDefaultPageSize,
}) })
for pager.More() { for pager.More() {
@ -684,7 +637,8 @@ func (fs *AzureBlobFs) Walk(root string, walkFn filepath.WalkFunc) error {
Include: container.ListBlobsInclude{ Include: container.ListBlobsInclude{
Metadata: true, Metadata: true,
}, },
Prefix: &prefix, Prefix: &prefix,
MaxResults: &azureBlobDefaultPageSize,
}) })
for pager.More() { for pager.More() {
@ -863,7 +817,7 @@ func (fs *AzureBlobFs) copyFileInternal(source, target string) error {
return nil return nil
} }
func (fs *AzureBlobFs) renameInternal(source, target string, fi os.FileInfo) (int, int64, error) { func (fs *AzureBlobFs) renameInternal(source, target string, fi os.FileInfo, recursion int) (int, int64, error) {
var numFiles int var numFiles int
var filesSize int64 var filesSize int64
@ -881,24 +835,12 @@ func (fs *AzureBlobFs) renameInternal(source, target string, fi os.FileInfo) (in
return numFiles, filesSize, err return numFiles, filesSize, err
} }
if renameMode == 1 { if renameMode == 1 {
entries, err := fs.ReadDir(source) files, size, err := doRecursiveRename(fs, source, target, fs.renameInternal, recursion)
numFiles += files
filesSize += size
if err != nil { if err != nil {
return numFiles, filesSize, err return numFiles, filesSize, err
} }
for _, info := range entries {
sourceEntry := fs.Join(source, info.Name())
targetEntry := fs.Join(target, info.Name())
files, size, err := fs.renameInternal(sourceEntry, targetEntry, info)
if err != nil {
if fs.IsNotExist(err) {
fsLog(fs, logger.LevelInfo, "skipping rename for %q: %v", sourceEntry, err)
continue
}
return numFiles, filesSize, err
}
numFiles += files
filesSize += size
}
} }
} else { } else {
if err := fs.copyFileInternal(source, target); err != nil { if err := fs.copyFileInternal(source, target); err != nil {
@ -1312,3 +1254,80 @@ func (b *bufferAllocator) free() {
b.available = nil b.available = nil
b.finalized = true b.finalized = true
} }
type azureBlobDirLister struct {
baseDirLister
paginator *runtime.Pager[container.ListBlobsHierarchyResponse]
timeout time.Duration
prefix string
prefixes map[string]bool
metricUpdated bool
}
func (l *azureBlobDirLister) Next(limit int) ([]os.FileInfo, error) {
if limit <= 0 {
return nil, errInvalidDirListerLimit
}
if len(l.cache) >= limit {
return l.returnFromCache(limit), nil
}
if !l.paginator.More() {
if !l.metricUpdated {
l.metricUpdated = true
metric.AZListObjectsCompleted(nil)
}
return l.returnFromCache(limit), io.EOF
}
ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(l.timeout))
defer cancelFn()
page, err := l.paginator.NextPage(ctx)
if err != nil {
metric.AZListObjectsCompleted(err)
return l.cache, err
}
for _, blobPrefix := range page.ListBlobsHierarchySegmentResponse.Segment.BlobPrefixes {
name := util.GetStringFromPointer(blobPrefix.Name)
// we don't support prefixes == "/" this will be sent if a key starts with "/"
if name == "" || name == "/" {
continue
}
// sometime we have duplicate prefixes, maybe an Azurite bug
name = strings.TrimPrefix(name, l.prefix)
if _, ok := l.prefixes[strings.TrimSuffix(name, "/")]; ok {
continue
}
l.cache = append(l.cache, NewFileInfo(name, true, 0, time.Unix(0, 0), false))
l.prefixes[strings.TrimSuffix(name, "/")] = true
}
for _, blobItem := range page.ListBlobsHierarchySegmentResponse.Segment.BlobItems {
name := util.GetStringFromPointer(blobItem.Name)
name = strings.TrimPrefix(name, l.prefix)
size := int64(0)
isDir := false
modTime := time.Unix(0, 0)
if blobItem.Properties != nil {
size = util.GetIntFromPointer(blobItem.Properties.ContentLength)
modTime = util.GetTimeFromPointer(blobItem.Properties.LastModified)
contentType := util.GetStringFromPointer(blobItem.Properties.ContentType)
isDir = checkDirectoryMarkers(contentType, blobItem.Metadata)
if isDir {
// check if the dir is already included, it will be sent as blob prefix if it contains at least one item
if _, ok := l.prefixes[name]; ok {
continue
}
l.prefixes[name] = true
}
}
l.cache = append(l.cache, NewFileInfo(name, isDir, size, modTime, false))
}
return l.returnFromCache(limit), nil
}
func (l *azureBlobDirLister) Close() error {
clear(l.prefixes)
return l.baseDirLister.Close()
}

View file

@ -221,21 +221,16 @@ func (*CryptFs) Truncate(_ string, _ int64) error {
// ReadDir reads the directory named by dirname and returns // ReadDir reads the directory named by dirname and returns
// a list of directory entries. // a list of directory entries.
func (fs *CryptFs) ReadDir(dirname string) ([]os.FileInfo, error) { func (fs *CryptFs) ReadDir(dirname string) (DirLister, error) {
f, err := os.Open(dirname) f, err := os.Open(dirname)
if err != nil { if err != nil {
if isInvalidNameError(err) {
err = os.ErrNotExist
}
return nil, err return nil, err
} }
list, err := f.Readdir(-1)
f.Close() return &cryptFsDirLister{f}, nil
if err != nil {
return nil, err
}
result := make([]os.FileInfo, 0, len(list))
for _, info := range list {
result = append(result, fs.ConvertFileInfo(info))
}
return result, nil
} }
// IsUploadResumeSupported returns false sio does not support random access writes // IsUploadResumeSupported returns false sio does not support random access writes
@ -289,20 +284,7 @@ func (fs *CryptFs) getSIOConfig(key [32]byte) sio.Config {
// ConvertFileInfo returns a FileInfo with the decrypted size // ConvertFileInfo returns a FileInfo with the decrypted size
func (fs *CryptFs) ConvertFileInfo(info os.FileInfo) os.FileInfo { func (fs *CryptFs) ConvertFileInfo(info os.FileInfo) os.FileInfo {
if !info.Mode().IsRegular() { return convertCryptFsInfo(info)
return info
}
size := info.Size()
if size >= headerV10Size {
size -= headerV10Size
decryptedSize, err := sio.DecryptedSize(uint64(size))
if err == nil {
size = int64(decryptedSize)
}
} else {
size = 0
}
return NewFileInfo(info.Name(), info.IsDir(), size, info.ModTime(), false)
} }
func (fs *CryptFs) getFileAndEncryptionKey(name string) (*os.File, [32]byte, error) { func (fs *CryptFs) getFileAndEncryptionKey(name string) (*os.File, [32]byte, error) {
@ -366,6 +348,23 @@ func isZeroBytesDownload(f *os.File, offset int64) (bool, error) {
return false, nil return false, nil
} }
func convertCryptFsInfo(info os.FileInfo) os.FileInfo {
if !info.Mode().IsRegular() {
return info
}
size := info.Size()
if size >= headerV10Size {
size -= headerV10Size
decryptedSize, err := sio.DecryptedSize(uint64(size))
if err == nil {
size = int64(decryptedSize)
}
} else {
size = 0
}
return NewFileInfo(info.Name(), info.IsDir(), size, info.ModTime(), false)
}
type encryptedFileHeader struct { type encryptedFileHeader struct {
version byte version byte
nonce []byte nonce []byte
@ -400,3 +399,22 @@ type cryptedFileWrapper struct {
func (w *cryptedFileWrapper) ReadAt(p []byte, offset int64) (n int, err error) { func (w *cryptedFileWrapper) ReadAt(p []byte, offset int64) (n int, err error) {
return w.File.ReadAt(p, offset+headerV10Size) return w.File.ReadAt(p, offset+headerV10Size)
} }
type cryptFsDirLister struct {
f *os.File
}
func (l *cryptFsDirLister) Next(limit int) ([]os.FileInfo, error) {
if limit <= 0 {
return nil, errInvalidDirListerLimit
}
files, err := l.f.Readdir(limit)
for idx := range files {
files[idx] = convertCryptFsInfo(files[idx])
}
return files, err
}
func (l *cryptFsDirLister) Close() error {
return l.f.Close()
}

View file

@ -266,7 +266,7 @@ func (fs *GCSFs) Rename(source, target string) (int, int64, error) {
if err != nil { if err != nil {
return -1, -1, err return -1, -1, err
} }
return fs.renameInternal(source, target, fi) return fs.renameInternal(source, target, fi, 0)
} }
// Remove removes the named file or (empty) directory. // Remove removes the named file or (empty) directory.
@ -369,80 +369,23 @@ func (*GCSFs) Truncate(_ string, _ int64) error {
// ReadDir reads the directory named by dirname and returns // ReadDir reads the directory named by dirname and returns
// a list of directory entries. // a list of directory entries.
func (fs *GCSFs) ReadDir(dirname string) ([]os.FileInfo, error) { func (fs *GCSFs) ReadDir(dirname string) (DirLister, error) {
var result []os.FileInfo
// dirname must be already cleaned // dirname must be already cleaned
prefix := fs.getPrefix(dirname) prefix := fs.getPrefix(dirname)
query := &storage.Query{Prefix: prefix, Delimiter: "/"} query := &storage.Query{Prefix: prefix, Delimiter: "/"}
err := query.SetAttrSelection(gcsDefaultFieldsSelection) err := query.SetAttrSelection(gcsDefaultFieldsSelection)
if err != nil { if err != nil {
return nil, err return nil, err
} }
modTimes, err := getFolderModTimes(fs.getStorageID(), dirname)
if err != nil {
return result, err
}
prefixes := make(map[string]bool)
ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxLongTimeout))
defer cancelFn()
bkt := fs.svc.Bucket(fs.config.Bucket) bkt := fs.svc.Bucket(fs.config.Bucket)
it := bkt.Objects(ctx, query)
pager := iterator.NewPager(it, defaultGCSPageSize, "")
for { return &gcsDirLister{
var objects []*storage.ObjectAttrs bucket: bkt,
pageToken, err := pager.NextPage(&objects) query: query,
if err != nil { timeout: fs.ctxTimeout,
metric.GCSListObjectsCompleted(err) prefix: prefix,
return result, err prefixes: make(map[string]bool),
} }, nil
for _, attrs := range objects {
if attrs.Prefix != "" {
name, _ := fs.resolve(attrs.Prefix, prefix, attrs.ContentType)
if name == "" {
continue
}
if _, ok := prefixes[name]; ok {
continue
}
result = append(result, NewFileInfo(name, true, 0, time.Unix(0, 0), false))
prefixes[name] = true
} else {
name, isDir := fs.resolve(attrs.Name, prefix, attrs.ContentType)
if name == "" {
continue
}
if !attrs.Deleted.IsZero() {
continue
}
if isDir {
// check if the dir is already included, it will be sent as blob prefix if it contains at least one item
if _, ok := prefixes[name]; ok {
continue
}
prefixes[name] = true
}
modTime := attrs.Updated
if t, ok := modTimes[name]; ok {
modTime = util.GetTimeFromMsecSinceEpoch(t)
}
result = append(result, NewFileInfo(name, isDir, attrs.Size, modTime, false))
}
}
objects = nil
if pageToken == "" {
break
}
}
metric.GCSListObjectsCompleted(nil)
return result, nil
} }
// IsUploadResumeSupported returns true if resuming uploads is supported. // IsUploadResumeSupported returns true if resuming uploads is supported.
@ -853,7 +796,7 @@ func (fs *GCSFs) copyFileInternal(source, target string) error {
return err return err
} }
func (fs *GCSFs) renameInternal(source, target string, fi os.FileInfo) (int, int64, error) { func (fs *GCSFs) renameInternal(source, target string, fi os.FileInfo, recursion int) (int, int64, error) {
var numFiles int var numFiles int
var filesSize int64 var filesSize int64
@ -871,24 +814,12 @@ func (fs *GCSFs) renameInternal(source, target string, fi os.FileInfo) (int, int
return numFiles, filesSize, err return numFiles, filesSize, err
} }
if renameMode == 1 { if renameMode == 1 {
entries, err := fs.ReadDir(source) files, size, err := doRecursiveRename(fs, source, target, fs.renameInternal, recursion)
numFiles += files
filesSize += size
if err != nil { if err != nil {
return numFiles, filesSize, err return numFiles, filesSize, err
} }
for _, info := range entries {
sourceEntry := fs.Join(source, info.Name())
targetEntry := fs.Join(target, info.Name())
files, size, err := fs.renameInternal(sourceEntry, targetEntry, info)
if err != nil {
if fs.IsNotExist(err) {
fsLog(fs, logger.LevelInfo, "skipping rename for %q: %v", sourceEntry, err)
continue
}
return numFiles, filesSize, err
}
numFiles += files
filesSize += size
}
} }
} else { } else {
if err := fs.copyFileInternal(source, target); err != nil { if err := fs.copyFileInternal(source, target); err != nil {
@ -1010,3 +941,97 @@ func (*GCSFs) getTempObject(name string) string {
func (fs *GCSFs) getStorageID() string { func (fs *GCSFs) getStorageID() string {
return fmt.Sprintf("gs://%v", fs.config.Bucket) return fmt.Sprintf("gs://%v", fs.config.Bucket)
} }
type gcsDirLister struct {
baseDirLister
bucket *storage.BucketHandle
query *storage.Query
timeout time.Duration
nextPageToken string
noMorePages bool
prefix string
prefixes map[string]bool
metricUpdated bool
}
func (l *gcsDirLister) resolve(name, contentType string) (string, bool) {
result := strings.TrimPrefix(name, l.prefix)
isDir := strings.HasSuffix(result, "/")
if isDir {
result = strings.TrimSuffix(result, "/")
}
if contentType == dirMimeType {
isDir = true
}
return result, isDir
}
func (l *gcsDirLister) Next(limit int) ([]os.FileInfo, error) {
if limit <= 0 {
return nil, errInvalidDirListerLimit
}
if len(l.cache) >= limit {
return l.returnFromCache(limit), nil
}
if l.noMorePages {
if !l.metricUpdated {
l.metricUpdated = true
metric.GCSListObjectsCompleted(nil)
}
return l.returnFromCache(limit), io.EOF
}
ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(l.timeout))
defer cancelFn()
it := l.bucket.Objects(ctx, l.query)
paginator := iterator.NewPager(it, defaultGCSPageSize, l.nextPageToken)
var objects []*storage.ObjectAttrs
pageToken, err := paginator.NextPage(&objects)
if err != nil {
metric.GCSListObjectsCompleted(err)
return l.cache, err
}
for _, attrs := range objects {
if attrs.Prefix != "" {
name, _ := l.resolve(attrs.Prefix, attrs.ContentType)
if name == "" {
continue
}
if _, ok := l.prefixes[name]; ok {
continue
}
l.cache = append(l.cache, NewFileInfo(name, true, 0, time.Unix(0, 0), false))
l.prefixes[name] = true
} else {
name, isDir := l.resolve(attrs.Name, attrs.ContentType)
if name == "" {
continue
}
if !attrs.Deleted.IsZero() {
continue
}
if isDir {
// check if the dir is already included, it will be sent as blob prefix if it contains at least one item
if _, ok := l.prefixes[name]; ok {
continue
}
l.prefixes[name] = true
}
l.cache = append(l.cache, NewFileInfo(name, isDir, attrs.Size, attrs.Updated, false))
}
}
l.nextPageToken = pageToken
l.noMorePages = (l.nextPageToken == "")
return l.returnFromCache(limit), nil
}
func (l *gcsDirLister) Close() error {
clear(l.prefixes)
return l.baseDirLister.Close()
}

View file

@ -488,7 +488,7 @@ func (fs *HTTPFs) Truncate(name string, size int64) error {
// ReadDir reads the directory named by dirname and returns // ReadDir reads the directory named by dirname and returns
// a list of directory entries. // a list of directory entries.
func (fs *HTTPFs) ReadDir(dirname string) ([]os.FileInfo, error) { func (fs *HTTPFs) ReadDir(dirname string) (DirLister, error) {
ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout))
defer cancelFn() defer cancelFn()
@ -511,7 +511,7 @@ func (fs *HTTPFs) ReadDir(dirname string) ([]os.FileInfo, error) {
for _, stat := range response { for _, stat := range response {
result = append(result, stat.getFileInfo()) result = append(result, stat.getFileInfo())
} }
return result, nil return &baseDirLister{result}, nil
} }
// IsUploadResumeSupported returns true if resuming uploads is supported. // IsUploadResumeSupported returns true if resuming uploads is supported.
@ -731,19 +731,33 @@ func (fs *HTTPFs) walk(filePath string, info fs.FileInfo, walkFn filepath.WalkFu
if !info.IsDir() { if !info.IsDir() {
return walkFn(filePath, info, nil) return walkFn(filePath, info, nil)
} }
files, err := fs.ReadDir(filePath) lister, err := fs.ReadDir(filePath)
err1 := walkFn(filePath, info, err) err1 := walkFn(filePath, info, err)
if err != nil || err1 != nil { if err != nil || err1 != nil {
if err == nil {
lister.Close()
}
return err1 return err1
} }
for _, fi := range files { defer lister.Close()
objName := path.Join(filePath, fi.Name())
err = fs.walk(objName, fi, walkFn) for {
if err != nil { files, err := lister.Next(ListerBatchSize)
finished := errors.Is(err, io.EOF)
if err != nil && !finished {
return err return err
} }
for _, fi := range files {
objName := path.Join(filePath, fi.Name())
err = fs.walk(objName, fi, walkFn)
if err != nil {
return err
}
}
if finished {
return nil
}
} }
return nil
} }
func getErrorFromResponseCode(code int) error { func getErrorFromResponseCode(code int) error {

View file

@ -190,7 +190,7 @@ func (fs *OsFs) Rename(source, target string) (int, int64, error) {
} }
err = fscopy.Copy(source, target, fscopy.Options{ err = fscopy.Copy(source, target, fscopy.Options{
OnSymlink: func(src string) fscopy.SymlinkAction { OnSymlink: func(_ string) fscopy.SymlinkAction {
return fscopy.Skip return fscopy.Skip
}, },
CopyBufferSize: readBufferSize, CopyBufferSize: readBufferSize,
@ -258,7 +258,7 @@ func (*OsFs) Truncate(name string, size int64) error {
// ReadDir reads the directory named by dirname and returns // ReadDir reads the directory named by dirname and returns
// a list of directory entries. // a list of directory entries.
func (*OsFs) ReadDir(dirname string) ([]os.FileInfo, error) { func (*OsFs) ReadDir(dirname string) (DirLister, error) {
f, err := os.Open(dirname) f, err := os.Open(dirname)
if err != nil { if err != nil {
if isInvalidNameError(err) { if isInvalidNameError(err) {
@ -266,12 +266,7 @@ func (*OsFs) ReadDir(dirname string) ([]os.FileInfo, error) {
} }
return nil, err return nil, err
} }
list, err := f.Readdir(-1) return &osFsDirLister{f}, nil
f.Close()
if err != nil {
return nil, err
}
return list, nil
} }
// IsUploadResumeSupported returns true if resuming uploads is supported // IsUploadResumeSupported returns true if resuming uploads is supported
@ -599,3 +594,18 @@ func (fs *OsFs) useWriteBuffering(flag int) bool {
} }
return true return true
} }
type osFsDirLister struct {
f *os.File
}
func (l *osFsDirLister) Next(limit int) ([]os.FileInfo, error) {
if limit <= 0 {
return nil, errInvalidDirListerLimit
}
return l.f.Readdir(limit)
}
func (l *osFsDirLister) Close() error {
return l.f.Close()
}

View file

@ -22,6 +22,7 @@ import (
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
"io"
"mime" "mime"
"net" "net"
"net/http" "net/http"
@ -61,7 +62,8 @@ const (
) )
var ( var (
s3DirMimeTypes = []string{s3DirMimeType, "httpd/unix-directory"} s3DirMimeTypes = []string{s3DirMimeType, "httpd/unix-directory"}
s3DefaultPageSize = int32(5000)
) )
// S3Fs is a Fs implementation for AWS S3 compatible object storages // S3Fs is a Fs implementation for AWS S3 compatible object storages
@ -337,7 +339,7 @@ func (fs *S3Fs) Rename(source, target string) (int, int64, error) {
if err != nil { if err != nil {
return -1, -1, err return -1, -1, err
} }
return fs.renameInternal(source, target, fi) return fs.renameInternal(source, target, fi, 0)
} }
// Remove removes the named file or (empty) directory. // Remove removes the named file or (empty) directory.
@ -426,68 +428,22 @@ func (*S3Fs) Truncate(_ string, _ int64) error {
// ReadDir reads the directory named by dirname and returns // ReadDir reads the directory named by dirname and returns
// a list of directory entries. // a list of directory entries.
func (fs *S3Fs) ReadDir(dirname string) ([]os.FileInfo, error) { func (fs *S3Fs) ReadDir(dirname string) (DirLister, error) {
var result []os.FileInfo
// dirname must be already cleaned // dirname must be already cleaned
prefix := fs.getPrefix(dirname) prefix := fs.getPrefix(dirname)
modTimes, err := getFolderModTimes(fs.getStorageID(), dirname)
if err != nil {
return result, err
}
prefixes := make(map[string]bool)
paginator := s3.NewListObjectsV2Paginator(fs.svc, &s3.ListObjectsV2Input{ paginator := s3.NewListObjectsV2Paginator(fs.svc, &s3.ListObjectsV2Input{
Bucket: aws.String(fs.config.Bucket), Bucket: aws.String(fs.config.Bucket),
Prefix: aws.String(prefix), Prefix: aws.String(prefix),
Delimiter: aws.String("/"), Delimiter: aws.String("/"),
MaxKeys: &s3DefaultPageSize,
}) })
for paginator.HasMorePages() { return &s3DirLister{
ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) paginator: paginator,
defer cancelFn() timeout: fs.ctxTimeout,
prefix: prefix,
page, err := paginator.NextPage(ctx) prefixes: make(map[string]bool),
if err != nil { }, nil
metric.S3ListObjectsCompleted(err)
return result, err
}
for _, p := range page.CommonPrefixes {
// prefixes have a trailing slash
name, _ := fs.resolve(p.Prefix, prefix)
if name == "" {
continue
}
if _, ok := prefixes[name]; ok {
continue
}
result = append(result, NewFileInfo(name, true, 0, time.Unix(0, 0), false))
prefixes[name] = true
}
for _, fileObject := range page.Contents {
objectModTime := util.GetTimeFromPointer(fileObject.LastModified)
objectSize := util.GetIntFromPointer(fileObject.Size)
name, isDir := fs.resolve(fileObject.Key, prefix)
if name == "" || name == "/" {
continue
}
if isDir {
if _, ok := prefixes[name]; ok {
continue
}
prefixes[name] = true
}
if t, ok := modTimes[name]; ok {
objectModTime = util.GetTimeFromMsecSinceEpoch(t)
}
result = append(result, NewFileInfo(name, (isDir && objectSize == 0), objectSize,
objectModTime, false))
}
}
metric.S3ListObjectsCompleted(nil)
return result, nil
} }
// IsUploadResumeSupported returns true if resuming uploads is supported. // IsUploadResumeSupported returns true if resuming uploads is supported.
@ -574,6 +530,7 @@ func (fs *S3Fs) getFileNamesInPrefix(fsPrefix string) (map[string]bool, error) {
Bucket: aws.String(fs.config.Bucket), Bucket: aws.String(fs.config.Bucket),
Prefix: aws.String(prefix), Prefix: aws.String(prefix),
Delimiter: aws.String("/"), Delimiter: aws.String("/"),
MaxKeys: &s3DefaultPageSize,
}) })
for paginator.HasMorePages() { for paginator.HasMorePages() {
@ -614,8 +571,9 @@ func (fs *S3Fs) GetDirSize(dirname string) (int, int64, error) {
size := int64(0) size := int64(0)
paginator := s3.NewListObjectsV2Paginator(fs.svc, &s3.ListObjectsV2Input{ paginator := s3.NewListObjectsV2Paginator(fs.svc, &s3.ListObjectsV2Input{
Bucket: aws.String(fs.config.Bucket), Bucket: aws.String(fs.config.Bucket),
Prefix: aws.String(prefix), Prefix: aws.String(prefix),
MaxKeys: &s3DefaultPageSize,
}) })
for paginator.HasMorePages() { for paginator.HasMorePages() {
@ -679,8 +637,9 @@ func (fs *S3Fs) Walk(root string, walkFn filepath.WalkFunc) error {
prefix := fs.getPrefix(root) prefix := fs.getPrefix(root)
paginator := s3.NewListObjectsV2Paginator(fs.svc, &s3.ListObjectsV2Input{ paginator := s3.NewListObjectsV2Paginator(fs.svc, &s3.ListObjectsV2Input{
Bucket: aws.String(fs.config.Bucket), Bucket: aws.String(fs.config.Bucket),
Prefix: aws.String(prefix), Prefix: aws.String(prefix),
MaxKeys: &s3DefaultPageSize,
}) })
for paginator.HasMorePages() { for paginator.HasMorePages() {
@ -797,7 +756,7 @@ func (fs *S3Fs) copyFileInternal(source, target string, fileSize int64) error {
return err return err
} }
func (fs *S3Fs) renameInternal(source, target string, fi os.FileInfo) (int, int64, error) { func (fs *S3Fs) renameInternal(source, target string, fi os.FileInfo, recursion int) (int, int64, error) {
var numFiles int var numFiles int
var filesSize int64 var filesSize int64
@ -815,24 +774,12 @@ func (fs *S3Fs) renameInternal(source, target string, fi os.FileInfo) (int, int6
return numFiles, filesSize, err return numFiles, filesSize, err
} }
if renameMode == 1 { if renameMode == 1 {
entries, err := fs.ReadDir(source) files, size, err := doRecursiveRename(fs, source, target, fs.renameInternal, recursion)
numFiles += files
filesSize += size
if err != nil { if err != nil {
return numFiles, filesSize, err return numFiles, filesSize, err
} }
for _, info := range entries {
sourceEntry := fs.Join(source, info.Name())
targetEntry := fs.Join(target, info.Name())
files, size, err := fs.renameInternal(sourceEntry, targetEntry, info)
if err != nil {
if fs.IsNotExist(err) {
fsLog(fs, logger.LevelInfo, "skipping rename for %q: %v", sourceEntry, err)
continue
}
return numFiles, filesSize, err
}
numFiles += files
filesSize += size
}
} }
} else { } else {
if err := fs.copyFileInternal(source, target, fi.Size()); err != nil { if err := fs.copyFileInternal(source, target, fi.Size()); err != nil {
@ -1114,6 +1061,81 @@ func (fs *S3Fs) getStorageID() string {
return fmt.Sprintf("s3://%v", fs.config.Bucket) return fmt.Sprintf("s3://%v", fs.config.Bucket)
} }
type s3DirLister struct {
baseDirLister
paginator *s3.ListObjectsV2Paginator
timeout time.Duration
prefix string
prefixes map[string]bool
metricUpdated bool
}
func (l *s3DirLister) resolve(name *string) (string, bool) {
result := strings.TrimPrefix(util.GetStringFromPointer(name), l.prefix)
isDir := strings.HasSuffix(result, "/")
if isDir {
result = strings.TrimSuffix(result, "/")
}
return result, isDir
}
func (l *s3DirLister) Next(limit int) ([]os.FileInfo, error) {
if limit <= 0 {
return nil, errInvalidDirListerLimit
}
if len(l.cache) >= limit {
return l.returnFromCache(limit), nil
}
if !l.paginator.HasMorePages() {
if !l.metricUpdated {
l.metricUpdated = true
metric.S3ListObjectsCompleted(nil)
}
return l.returnFromCache(limit), io.EOF
}
ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(l.timeout))
defer cancelFn()
page, err := l.paginator.NextPage(ctx)
if err != nil {
metric.S3ListObjectsCompleted(err)
return l.cache, err
}
for _, p := range page.CommonPrefixes {
// prefixes have a trailing slash
name, _ := l.resolve(p.Prefix)
if name == "" {
continue
}
if _, ok := l.prefixes[name]; ok {
continue
}
l.cache = append(l.cache, NewFileInfo(name, true, 0, time.Unix(0, 0), false))
l.prefixes[name] = true
}
for _, fileObject := range page.Contents {
objectModTime := util.GetTimeFromPointer(fileObject.LastModified)
objectSize := util.GetIntFromPointer(fileObject.Size)
name, isDir := l.resolve(fileObject.Key)
if name == "" || name == "/" {
continue
}
if isDir {
if _, ok := l.prefixes[name]; ok {
continue
}
l.prefixes[name] = true
}
l.cache = append(l.cache, NewFileInfo(name, (isDir && objectSize == 0), objectSize, objectModTime, false))
}
return l.returnFromCache(limit), nil
}
func (l *s3DirLister) Close() error {
return l.baseDirLister.Close()
}
func getAWSHTTPClient(timeout int, idleConnectionTimeout time.Duration, skipTLSVerify bool) *awshttp.BuildableClient { func getAWSHTTPClient(timeout int, idleConnectionTimeout time.Duration, skipTLSVerify bool) *awshttp.BuildableClient {
c := awshttp.NewBuildableClient(). c := awshttp.NewBuildableClient().
WithDialerOptions(func(d *net.Dialer) { WithDialerOptions(func(d *net.Dialer) {

View file

@ -541,12 +541,16 @@ func (fs *SFTPFs) Truncate(name string, size int64) error {
// ReadDir reads the directory named by dirname and returns // ReadDir reads the directory named by dirname and returns
// a list of directory entries. // a list of directory entries.
func (fs *SFTPFs) ReadDir(dirname string) ([]os.FileInfo, error) { func (fs *SFTPFs) ReadDir(dirname string) (DirLister, error) {
client, err := fs.conn.getClient() client, err := fs.conn.getClient()
if err != nil { if err != nil {
return nil, err return nil, err
} }
return client.ReadDir(dirname) files, err := client.ReadDir(dirname)
if err != nil {
return nil, err
}
return &baseDirLister{files}, nil
} }
// IsUploadResumeSupported returns true if resuming uploads is supported. // IsUploadResumeSupported returns true if resuming uploads is supported.

View file

@ -45,6 +45,8 @@ const (
gcsfsName = "GCSFs" gcsfsName = "GCSFs"
azBlobFsName = "AzureBlobFs" azBlobFsName = "AzureBlobFs"
preResumeTimeout = 90 * time.Second preResumeTimeout = 90 * time.Second
// ListerBatchSize defines the default limit for DirLister implementations
ListerBatchSize = 1000
) )
// Additional checks for files // Additional checks for files
@ -58,14 +60,15 @@ var (
// ErrStorageSizeUnavailable is returned if the storage backend does not support getting the size // ErrStorageSizeUnavailable is returned if the storage backend does not support getting the size
ErrStorageSizeUnavailable = errors.New("unable to get available size for this storage backend") ErrStorageSizeUnavailable = errors.New("unable to get available size for this storage backend")
// ErrVfsUnsupported defines the error for an unsupported VFS operation // ErrVfsUnsupported defines the error for an unsupported VFS operation
ErrVfsUnsupported = errors.New("not supported") ErrVfsUnsupported = errors.New("not supported")
tempPath string errInvalidDirListerLimit = errors.New("dir lister: invalid limit, must be > 0")
sftpFingerprints []string tempPath string
allowSelfConnections int sftpFingerprints []string
renameMode int allowSelfConnections int
readMetadata int renameMode int
resumeMaxSize int64 readMetadata int
uploadMode int resumeMaxSize int64
uploadMode int
) )
// SetAllowSelfConnections sets the desired behaviour for self connections // SetAllowSelfConnections sets the desired behaviour for self connections
@ -125,7 +128,7 @@ type Fs interface {
Chmod(name string, mode os.FileMode) error Chmod(name string, mode os.FileMode) error
Chtimes(name string, atime, mtime time.Time, isUploading bool) error Chtimes(name string, atime, mtime time.Time, isUploading bool) error
Truncate(name string, size int64) error Truncate(name string, size int64) error
ReadDir(dirname string) ([]os.FileInfo, error) ReadDir(dirname string) (DirLister, error)
Readlink(name string) (string, error) Readlink(name string) (string, error)
IsUploadResumeSupported() bool IsUploadResumeSupported() bool
IsConditionalUploadResumeSupported(size int64) bool IsConditionalUploadResumeSupported(size int64) bool
@ -199,11 +202,47 @@ type PipeReader interface {
Metadata() map[string]string Metadata() map[string]string
} }
// DirLister defines an interface for a directory lister
type DirLister interface {
Next(limit int) ([]os.FileInfo, error)
Close() error
}
// Metadater defines an interface to implement to return metadata for a file // Metadater defines an interface to implement to return metadata for a file
type Metadater interface { type Metadater interface {
Metadata() map[string]string Metadata() map[string]string
} }
type baseDirLister struct {
cache []os.FileInfo
}
func (l *baseDirLister) Next(limit int) ([]os.FileInfo, error) {
if limit <= 0 {
return nil, errInvalidDirListerLimit
}
if len(l.cache) >= limit {
return l.returnFromCache(limit), nil
}
return l.returnFromCache(limit), io.EOF
}
func (l *baseDirLister) returnFromCache(limit int) []os.FileInfo {
if len(l.cache) >= limit {
result := l.cache[:limit]
l.cache = l.cache[limit:]
return result
}
result := l.cache
l.cache = nil
return result
}
func (l *baseDirLister) Close() error {
l.cache = nil
return nil
}
// QuotaCheckResult defines the result for a quota check // QuotaCheckResult defines the result for a quota check
type QuotaCheckResult struct { type QuotaCheckResult struct {
HasSpace bool HasSpace bool
@ -1069,18 +1108,6 @@ func updateFileInfoModTime(storageID, objectPath string, info *FileInfo) (*FileI
return info, nil return info, nil
} }
func getFolderModTimes(storageID, dirName string) (map[string]int64, error) {
var err error
modTimes := make(map[string]int64)
if plugin.Handler.HasMetadater() {
modTimes, err = plugin.Handler.GetModificationTimes(storageID, ensureAbsPath(dirName))
if err != nil && !errors.Is(err, metadata.ErrNoSuchObject) {
return modTimes, err
}
}
return modTimes, nil
}
func ensureAbsPath(name string) string { func ensureAbsPath(name string) string {
if path.IsAbs(name) { if path.IsAbs(name) {
return name return name
@ -1205,6 +1232,50 @@ func getLocalTempDir() string {
return filepath.Clean(os.TempDir()) return filepath.Clean(os.TempDir())
} }
func doRecursiveRename(fs Fs, source, target string,
renameFn func(string, string, os.FileInfo, int) (int, int64, error),
recursion int,
) (int, int64, error) {
var numFiles int
var filesSize int64
if recursion > util.MaxRecursion {
return numFiles, filesSize, util.ErrRecursionTooDeep
}
recursion++
lister, err := fs.ReadDir(source)
if err != nil {
return numFiles, filesSize, err
}
defer lister.Close()
for {
entries, err := lister.Next(ListerBatchSize)
finished := errors.Is(err, io.EOF)
if err != nil && !finished {
return numFiles, filesSize, err
}
for _, info := range entries {
sourceEntry := fs.Join(source, info.Name())
targetEntry := fs.Join(target, info.Name())
files, size, err := renameFn(sourceEntry, targetEntry, info, recursion)
if err != nil {
if fs.IsNotExist(err) {
fsLog(fs, logger.LevelInfo, "skipping rename for %q: %v", sourceEntry, err)
continue
}
return numFiles, filesSize, err
}
numFiles += files
filesSize += size
}
if finished {
return numFiles, filesSize, nil
}
}
}
func fsLog(fs Fs, level logger.LogLevel, format string, v ...any) { func fsLog(fs Fs, level logger.LogLevel, format string, v ...any) {
logger.Log(level, fs.Name(), fs.ConnectionID(), format, v...) logger.Log(level, fs.Name(), fs.ConnectionID(), format, v...)
} }

View file

@ -108,22 +108,24 @@ func (fi *webDavFileInfo) ContentType(_ context.Context) (string, error) {
// Readdir reads directory entries from the handle // Readdir reads directory entries from the handle
func (f *webDavFile) Readdir(_ int) ([]os.FileInfo, error) { func (f *webDavFile) Readdir(_ int) ([]os.FileInfo, error) {
return nil, webdav.ErrNotImplemented
}
// ReadDir implements the FileDirLister interface
func (f *webDavFile) ReadDir() (webdav.DirLister, error) {
if !f.Connection.User.HasPerm(dataprovider.PermListItems, f.GetVirtualPath()) { if !f.Connection.User.HasPerm(dataprovider.PermListItems, f.GetVirtualPath()) {
return nil, f.Connection.GetPermissionDeniedError() return nil, f.Connection.GetPermissionDeniedError()
} }
entries, err := f.Connection.ListDir(f.GetVirtualPath()) lister, err := f.Connection.ListDir(f.GetVirtualPath())
if err != nil { if err != nil {
return nil, err return nil, err
} }
for idx, info := range entries { return &webDavDirLister{
entries[idx] = &webDavFileInfo{ DirLister: lister,
FileInfo: info, fs: f.Fs,
Fs: f.Fs, virtualDirPath: f.GetVirtualPath(),
virtualPath: path.Join(f.GetVirtualPath(), info.Name()), fsDirPath: f.GetFsPath(),
fsPath: f.Fs.Join(f.GetFsPath(), info.Name()), }, nil
}
}
return entries, nil
} }
// Stat the handle // Stat the handle
@ -474,3 +476,24 @@ func (f *webDavFile) Patch(patches []webdav.Proppatch) ([]webdav.Propstat, error
} }
return resp, nil return resp, nil
} }
type webDavDirLister struct {
vfs.DirLister
fs vfs.Fs
virtualDirPath string
fsDirPath string
}
func (l *webDavDirLister) Next(limit int) ([]os.FileInfo, error) {
files, err := l.DirLister.Next(limit)
for idx := range files {
info := files[idx]
files[idx] = &webDavFileInfo{
FileInfo: info,
Fs: l.fs,
virtualPath: path.Join(l.virtualDirPath, info.Name()),
fsPath: l.fs.Join(l.fsDirPath, info.Name()),
}
}
return files, err
}

View file

@ -692,6 +692,8 @@ func TestContentType(t *testing.T) {
assert.Equal(t, "application/custom-mime", ctype) assert.Equal(t, "application/custom-mime", ctype)
} }
_, err = davFile.Readdir(-1) _, err = davFile.Readdir(-1)
assert.ErrorIs(t, err, webdav.ErrNotImplemented)
_, err = davFile.ReadDir()
assert.Error(t, err) assert.Error(t, err)
err = davFile.Close() err = davFile.Close()
assert.NoError(t, err) assert.NoError(t, err)