add DirLister interface

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
Nicola Murino 2024-02-15 20:53:56 +01:00
parent c60eb050ef
commit 1ff55bbfa7
No known key found for this signature in database
GPG key ID: 935D2952DEC4EECF
35 changed files with 1362 additions and 669 deletions

11
go.mod
View file

@ -21,7 +21,7 @@ require (
github.com/bmatcuk/doublestar/v4 v4.6.1
github.com/cockroachdb/cockroach-go/v2 v2.3.6
github.com/coreos/go-oidc/v3 v3.9.0
github.com/drakkan/webdav v0.0.0-20230227175313-32996838bcd8
github.com/drakkan/webdav v0.0.0-20240212101318-94e905cb9adb
github.com/eikenb/pipeat v0.0.0-20210730190139-06b3e6902001
github.com/fclairamb/ftpserverlib v0.22.0
github.com/fclairamb/go-log v0.4.1
@ -71,7 +71,7 @@ require (
golang.org/x/crypto v0.18.0
golang.org/x/net v0.20.0
golang.org/x/oauth2 v0.16.0
golang.org/x/sys v0.16.0
golang.org/x/sys v0.17.0
golang.org/x/term v0.16.0
golang.org/x/time v0.5.0
google.golang.org/api v0.161.0
@ -118,7 +118,9 @@ require (
github.com/google/s2a-go v0.1.7 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
github.com/googleapis/gax-go/v2 v2.12.0 // indirect
github.com/hashicorp/errwrap v1.0.0 // indirect
github.com/hashicorp/go-cleanhttp v0.5.2 // indirect
github.com/hashicorp/go-multierror v1.1.1 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/hashicorp/yamux v0.1.1 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
@ -181,8 +183,9 @@ require (
)
replace (
github.com/fclairamb/ftpserverlib => github.com/drakkan/ftpserverlib v0.0.0-20230820193955-e7243edeb89b
github.com/jlaffaye/ftp => github.com/drakkan/ftp v0.0.0-20201114075148-9b9adce499a9
github.com/fclairamb/ftpserverlib => github.com/drakkan/ftpserverlib v0.0.0-20240212100826-a241365cb085
github.com/jlaffaye/ftp => github.com/drakkan/ftp v0.0.0-20240210102745-f1ffc43f78d2
github.com/pkg/sftp => github.com/drakkan/sftp v0.0.0-20240214104840-fbb0b8bdb30c
github.com/robfig/cron/v3 => github.com/drakkan/cron/v3 v3.0.0-20230222140221-217a1e4d96c0
golang.org/x/crypto => github.com/drakkan/crypto v0.0.0-20231218163632-74b52eafd2c0
)

23
go.sum
View file

@ -113,12 +113,14 @@ github.com/drakkan/cron/v3 v3.0.0-20230222140221-217a1e4d96c0 h1:EW9gIJRmt9lzk66
github.com/drakkan/cron/v3 v3.0.0-20230222140221-217a1e4d96c0/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
github.com/drakkan/crypto v0.0.0-20231218163632-74b52eafd2c0 h1:Yel8NcrK4jg+biIcTxnszKh0eIpF2Vj25XEygQcTweI=
github.com/drakkan/crypto v0.0.0-20231218163632-74b52eafd2c0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
github.com/drakkan/ftp v0.0.0-20201114075148-9b9adce499a9 h1:LPH1dEblAOO/LoG7yHPMtBLXhQmjaga91/DDjWk9jWA=
github.com/drakkan/ftp v0.0.0-20201114075148-9b9adce499a9/go.mod h1:2lmrmq866uF2tnje75wQHzmPXhmSWUt7Gyx2vgK1RCU=
github.com/drakkan/ftpserverlib v0.0.0-20230820193955-e7243edeb89b h1:sCtiYerLxfOQrSludkwGwwXLlSVHxpvfmyOxjCOf0ec=
github.com/drakkan/ftpserverlib v0.0.0-20230820193955-e7243edeb89b/go.mod h1:dI9/yw/KfJ0g4wmRK8ZukUfqakLr6ZTf9VDydKoLy90=
github.com/drakkan/webdav v0.0.0-20230227175313-32996838bcd8 h1:tdkLkSKtYd3WSDsZXGJDKsakiNstLQJPN5HjnqCkf2c=
github.com/drakkan/webdav v0.0.0-20230227175313-32996838bcd8/go.mod h1:zOVb1QDhwwqWn2L2qZ0U3swMSO4GTSNyIwXCGO/UGWE=
github.com/drakkan/ftp v0.0.0-20240210102745-f1ffc43f78d2 h1:ufiGMPFBjndWSQOst9FNP11IuMqPblI2NXbpRMUWNhk=
github.com/drakkan/ftp v0.0.0-20240210102745-f1ffc43f78d2/go.mod h1:4p8lUl4vQ80L598CygL+3IFtm+3nggvvW/palOlViwE=
github.com/drakkan/ftpserverlib v0.0.0-20240212100826-a241365cb085 h1:LAKYR9z9USKeyEQK91sRWldmMOjEHLOt2NuLDx+x1UQ=
github.com/drakkan/ftpserverlib v0.0.0-20240212100826-a241365cb085/go.mod h1:9rZ27KBV3xlXmjIfd6HynND28tse8ShZJ/NQkprCKno=
github.com/drakkan/sftp v0.0.0-20240214104840-fbb0b8bdb30c h1:usPo/2W6Dj2rugQiEml0pwmUfY/wUgW6nLGl+q98c5k=
github.com/drakkan/sftp v0.0.0-20240214104840-fbb0b8bdb30c/go.mod h1:KMKI0t3T6hfA+lTR/ssZdunHo+uwq7ghoN09/FSu3DY=
github.com/drakkan/webdav v0.0.0-20240212101318-94e905cb9adb h1:BLT+1m0U57PevUtACnTUoMErsLJyK2ydeNiVuX8R3Lk=
github.com/drakkan/webdav v0.0.0-20240212101318-94e905cb9adb/go.mod h1:zOVb1QDhwwqWn2L2qZ0U3swMSO4GTSNyIwXCGO/UGWE=
github.com/eikenb/pipeat v0.0.0-20210730190139-06b3e6902001 h1:/ZshrfQzayqRSBDodmp3rhNCHJCff+utvgBuWRbiqu4=
github.com/eikenb/pipeat v0.0.0-20210730190139-06b3e6902001/go.mod h1:kltMsfRMTHSFdMbK66XdS8mfMW77+FZA1fGY1xYMF84=
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
@ -216,11 +218,15 @@ github.com/googleapis/enterprise-certificate-proxy v0.3.2 h1:Vie5ybvEvT75RniqhfF
github.com/googleapis/enterprise-certificate-proxy v0.3.2/go.mod h1:VLSiSSBs/ksPL8kq3OBOQ6WRI2QnaFynd1DCjZ62+V0=
github.com/googleapis/gax-go/v2 v2.12.0 h1:A+gCJKdRfqXkr+BIRGtZLibNXf0m1f9E4HG56etFpas=
github.com/googleapis/gax-go/v2 v2.12.0/go.mod h1:y+aIqrI5eb1YGMVJfuV3185Ts/D7qKpsEkdD5+I6QGU=
github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ=
github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48=
github.com/hashicorp/go-hclog v0.9.2/go.mod h1:5CU+agLiy3J7N7QjHK5d05KxGsuXiQLrjA0H7acj2lQ=
github.com/hashicorp/go-hclog v1.6.2 h1:NOtoftovWkDheyUM/8JW3QMiXyxJK3uHRK7wV04nD2I=
github.com/hashicorp/go-hclog v1.6.2/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M=
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
github.com/hashicorp/go-plugin v1.6.0 h1:wgd4KxHJTVGGqWBq4QPB1i5BZNEx9BR8+OFmHDmTk8A=
github.com/hashicorp/go-plugin v1.6.0/go.mod h1:lBS5MtSSBZk0SHc66KACcjjlU6WzEVP/8pwz68aMkCI=
github.com/hashicorp/go-retryablehttp v0.7.5 h1:bJj+Pj19UZMIweq/iie+1u5YCdGrnxCT9yvm0e+Nd5M=
@ -311,8 +317,6 @@ github.com/pires/go-proxyproto v0.7.0/go.mod h1:Vz/1JPY/OACxWGQNIRY2BeyDmpoaWmEP
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 h1:KoWmjvw+nsYOo29YJK9vDA65RGE3NrOnUtO7a+RF9HU=
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8/go.mod h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/sftp v1.13.6 h1:JFZT4XbOU7l77xGSpOdW+pwIMqP044IyjXX6FGyEKFo=
github.com/pkg/sftp v1.13.6/go.mod h1:tz1ryNURKu77RL+GuCzmoJYxQczL3wLNNpPWagdg4Qk=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
@ -487,8 +491,9 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU=
golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=

View file

@ -297,7 +297,7 @@ func (c *BaseConnection) truncateOpenHandle(fsPath string, size int64) (int64, e
}
// ListDir reads the directory matching virtualPath and returns a list of directory entries
func (c *BaseConnection) ListDir(virtualPath string) ([]os.FileInfo, error) {
func (c *BaseConnection) ListDir(virtualPath string) (*DirListerAt, error) {
if !c.User.HasPerm(dataprovider.PermListItems, virtualPath) {
return nil, c.GetPermissionDeniedError()
}
@ -305,12 +305,17 @@ func (c *BaseConnection) ListDir(virtualPath string) ([]os.FileInfo, error) {
if err != nil {
return nil, err
}
files, err := fs.ReadDir(fsPath)
lister, err := fs.ReadDir(fsPath)
if err != nil {
c.Log(logger.LevelDebug, "error listing directory: %+v", err)
return nil, c.GetFsError(fs, err)
}
return c.User.FilterListDir(files, virtualPath), nil
return &DirListerAt{
virtualPath: virtualPath,
user: &c.User,
info: c.User.GetVirtualFoldersInfo(virtualPath),
lister: lister,
}, nil
}
// CheckParentDirs tries to create the specified directory and any missing parent dirs
@ -511,24 +516,42 @@ func (c *BaseConnection) RemoveDir(virtualPath string) error {
return nil
}
func (c *BaseConnection) doRecursiveRemoveDirEntry(virtualPath string, info os.FileInfo) error {
func (c *BaseConnection) doRecursiveRemoveDirEntry(virtualPath string, info os.FileInfo, recursion int) error {
fs, fsPath, err := c.GetFsAndResolvedPath(virtualPath)
if err != nil {
return err
}
return c.doRecursiveRemove(fs, fsPath, virtualPath, info)
return c.doRecursiveRemove(fs, fsPath, virtualPath, info, recursion)
}
func (c *BaseConnection) doRecursiveRemove(fs vfs.Fs, fsPath, virtualPath string, info os.FileInfo) error {
func (c *BaseConnection) doRecursiveRemove(fs vfs.Fs, fsPath, virtualPath string, info os.FileInfo, recursion int) error {
if info.IsDir() {
entries, err := c.ListDir(virtualPath)
if err != nil {
return fmt.Errorf("unable to get contents for dir %q: %w", virtualPath, err)
if recursion >= util.MaxRecursion {
c.Log(logger.LevelError, "recursive rename failed, recursion too depth: %d", recursion)
return util.ErrRecursionTooDeep
}
for _, fi := range entries {
targetPath := path.Join(virtualPath, fi.Name())
if err := c.doRecursiveRemoveDirEntry(targetPath, fi); err != nil {
return err
recursion++
lister, err := c.ListDir(virtualPath)
if err != nil {
return fmt.Errorf("unable to get lister for dir %q: %w", virtualPath, err)
}
defer lister.Close()
for {
entries, err := lister.Next(vfs.ListerBatchSize)
finished := errors.Is(err, io.EOF)
if err != nil && !finished {
return fmt.Errorf("unable to get content for dir %q: %w", virtualPath, err)
}
for _, fi := range entries {
targetPath := path.Join(virtualPath, fi.Name())
if err := c.doRecursiveRemoveDirEntry(targetPath, fi, recursion); err != nil {
return err
}
}
if finished {
lister.Close()
break
}
}
return c.RemoveDir(virtualPath)
@ -552,7 +575,7 @@ func (c *BaseConnection) RemoveAll(virtualPath string) error {
if err := c.IsRemoveDirAllowed(fs, fsPath, virtualPath); err != nil {
return err
}
return c.doRecursiveRemove(fs, fsPath, virtualPath, fi)
return c.doRecursiveRemove(fs, fsPath, virtualPath, fi, 0)
}
return c.RemoveFile(fs, fsPath, virtualPath, fi)
}
@ -626,43 +649,38 @@ func (c *BaseConnection) copyFile(virtualSourcePath, virtualTargetPath string, s
}
func (c *BaseConnection) doRecursiveCopy(virtualSourcePath, virtualTargetPath string, srcInfo os.FileInfo,
createTargetDir bool,
createTargetDir bool, recursion int,
) error {
if srcInfo.IsDir() {
if recursion >= util.MaxRecursion {
c.Log(logger.LevelError, "recursive copy failed, recursion too depth: %d", recursion)
return util.ErrRecursionTooDeep
}
recursion++
if createTargetDir {
if err := c.CreateDir(virtualTargetPath, false); err != nil {
return fmt.Errorf("unable to create directory %q: %w", virtualTargetPath, err)
}
}
entries, err := c.ListDir(virtualSourcePath)
lister, err := c.ListDir(virtualSourcePath)
if err != nil {
return fmt.Errorf("unable to get contents for dir %q: %w", virtualSourcePath, err)
return fmt.Errorf("unable to get lister for dir %q: %w", virtualSourcePath, err)
}
for _, info := range entries {
sourcePath := path.Join(virtualSourcePath, info.Name())
targetPath := path.Join(virtualTargetPath, info.Name())
targetInfo, err := c.DoStat(targetPath, 1, false)
if err == nil {
if info.IsDir() && targetInfo.IsDir() {
c.Log(logger.LevelDebug, "target copy dir %q already exists", targetPath)
continue
}
defer lister.Close()
for {
entries, err := lister.Next(vfs.ListerBatchSize)
finished := errors.Is(err, io.EOF)
if err != nil && !finished {
return fmt.Errorf("unable to get contents for dir %q: %w", virtualSourcePath, err)
}
if err != nil && !c.IsNotExistError(err) {
if err := c.recursiveCopyEntries(virtualSourcePath, virtualTargetPath, entries, recursion); err != nil {
return err
}
if err := c.checkCopy(info, targetInfo, sourcePath, targetPath); err != nil {
return err
}
if err := c.doRecursiveCopy(sourcePath, targetPath, info, true); err != nil {
if c.IsNotExistError(err) {
c.Log(logger.LevelInfo, "skipping copy for source path %q: %v", sourcePath, err)
continue
}
return err
if finished {
return nil
}
}
return nil
}
if !srcInfo.Mode().IsRegular() {
c.Log(logger.LevelInfo, "skipping copy for non regular file %q", virtualSourcePath)
@ -672,6 +690,34 @@ func (c *BaseConnection) doRecursiveCopy(virtualSourcePath, virtualTargetPath st
return c.copyFile(virtualSourcePath, virtualTargetPath, srcInfo.Size())
}
func (c *BaseConnection) recursiveCopyEntries(virtualSourcePath, virtualTargetPath string, entries []os.FileInfo, recursion int) error {
for _, info := range entries {
sourcePath := path.Join(virtualSourcePath, info.Name())
targetPath := path.Join(virtualTargetPath, info.Name())
targetInfo, err := c.DoStat(targetPath, 1, false)
if err == nil {
if info.IsDir() && targetInfo.IsDir() {
c.Log(logger.LevelDebug, "target copy dir %q already exists", targetPath)
continue
}
}
if err != nil && !c.IsNotExistError(err) {
return err
}
if err := c.checkCopy(info, targetInfo, sourcePath, targetPath); err != nil {
return err
}
if err := c.doRecursiveCopy(sourcePath, targetPath, info, true, recursion); err != nil {
if c.IsNotExistError(err) {
c.Log(logger.LevelInfo, "skipping copy for source path %q: %v", sourcePath, err)
continue
}
return err
}
}
return nil
}
// Copy virtualSourcePath to virtualTargetPath
func (c *BaseConnection) Copy(virtualSourcePath, virtualTargetPath string) error {
copyFromSource := strings.HasSuffix(virtualSourcePath, "/")
@ -717,7 +763,7 @@ func (c *BaseConnection) Copy(virtualSourcePath, virtualTargetPath string) error
defer close(done)
go keepConnectionAlive(c, done, 2*time.Minute)
return c.doRecursiveCopy(virtualSourcePath, destPath, srcInfo, createTargetDir)
return c.doRecursiveCopy(virtualSourcePath, destPath, srcInfo, createTargetDir, 0)
}
// Rename renames (moves) virtualSourcePath to virtualTargetPath
@ -865,7 +911,8 @@ func (c *BaseConnection) doStatInternal(virtualPath string, mode int, checkFileP
convertResult bool,
) (os.FileInfo, error) {
// for some vfs we don't create intermediary folders so we cannot simply check
// if virtualPath is a virtual folder
// if virtualPath is a virtual folder. Allowing stat for hidden virtual folders
// is by purpose.
vfolders := c.User.GetVirtualFoldersInPath(path.Dir(virtualPath))
if _, ok := vfolders[virtualPath]; ok {
return vfs.NewFileInfo(virtualPath, true, 0, time.Unix(0, 0), false), nil
@ -1739,6 +1786,83 @@ func (c *BaseConnection) GetFsAndResolvedPath(virtualPath string) (vfs.Fs, strin
return fs, fsPath, nil
}
// DirListerAt defines a directory lister implementing the ListAt method.
type DirListerAt struct {
virtualPath string
user *dataprovider.User
info []os.FileInfo
mu sync.Mutex
lister vfs.DirLister
}
// Add adds the given os.FileInfo to the internal cache
func (l *DirListerAt) Add(fi os.FileInfo) {
l.mu.Lock()
defer l.mu.Unlock()
l.info = append(l.info, fi)
}
// ListAt implements sftp.ListerAt
func (l *DirListerAt) ListAt(f []os.FileInfo, _ int64) (int, error) {
l.mu.Lock()
defer l.mu.Unlock()
if len(f) == 0 {
return 0, errors.New("invalid ListAt destination, zero size")
}
if len(f) <= len(l.info) {
files := make([]os.FileInfo, 0, len(f))
for idx := len(l.info) - 1; idx >= 0; idx-- {
files = append(files, l.info[idx])
if len(files) == len(f) {
l.info = l.info[:idx]
n := copy(f, files)
return n, nil
}
}
}
limit := len(f) - len(l.info)
files, err := l.Next(limit)
n := copy(f, files)
return n, err
}
// Next reads the directory and returns a slice of up to n FileInfo values.
func (l *DirListerAt) Next(limit int) ([]os.FileInfo, error) {
for {
files, err := l.lister.Next(limit)
if err != nil && !errors.Is(err, io.EOF) {
return files, err
}
files = l.user.FilterListDir(files, l.virtualPath)
if len(l.info) > 0 {
for _, fi := range l.info {
files = util.PrependFileInfo(files, fi)
}
l.info = nil
}
if err != nil || len(files) > 0 {
return files, err
}
}
}
// Close closes the DirListerAt
func (l *DirListerAt) Close() error {
l.mu.Lock()
defer l.mu.Unlock()
return l.lister.Close()
}
func (l *DirListerAt) convertError(err error) error {
if errors.Is(err, io.EOF) {
return nil
}
return err
}
func getPermissionDeniedError(protocol string) error {
switch protocol {
case ProtocolSFTP:

View file

@ -17,10 +17,12 @@ package common
import (
"errors"
"fmt"
"io"
"os"
"path"
"path/filepath"
"runtime"
"strconv"
"testing"
"time"
@ -601,8 +603,10 @@ func TestErrorResolvePath(t *testing.T) {
}
conn := NewBaseConnection("", ProtocolSFTP, "", "", u)
err := conn.doRecursiveRemoveDirEntry("/vpath", nil)
err := conn.doRecursiveRemoveDirEntry("/vpath", nil, 0)
assert.Error(t, err)
err = conn.doRecursiveRemove(nil, "/fspath", "/vpath", vfs.NewFileInfo("vpath", true, 0, time.Now(), false), 2000)
assert.Error(t, err, util.ErrRecursionTooDeep)
err = conn.checkCopy(vfs.NewFileInfo("name", true, 0, time.Unix(0, 0), false), nil, "/source", "/target")
assert.Error(t, err)
sourceFile := filepath.Join(os.TempDir(), "f", "source")
@ -700,26 +704,32 @@ func TestFilePatterns(t *testing.T) {
VirtualFolders: virtualFolders,
}
getFilteredInfo := func(dirContents []os.FileInfo, virtualPath string) []os.FileInfo {
result := user.FilterListDir(dirContents, virtualPath)
result = append(result, user.GetVirtualFoldersInfo(virtualPath)...)
return result
}
dirContents := []os.FileInfo{
vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false),
vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false),
}
// dirContents are modified in place, we need to redefine them each time
filtered := user.FilterListDir(dirContents, "/dir1")
filtered := getFilteredInfo(dirContents, "/dir1")
assert.Len(t, filtered, 5)
dirContents = []os.FileInfo{
vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false),
vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false),
}
filtered = user.FilterListDir(dirContents, "/dir1/vdir1")
filtered = getFilteredInfo(dirContents, "/dir1/vdir1")
assert.Len(t, filtered, 2)
dirContents = []os.FileInfo{
vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false),
vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false),
}
filtered = user.FilterListDir(dirContents, "/dir2/vdir2")
filtered = getFilteredInfo(dirContents, "/dir2/vdir2")
require.Len(t, filtered, 1)
assert.Equal(t, "file1.jpg", filtered[0].Name())
@ -727,7 +737,7 @@ func TestFilePatterns(t *testing.T) {
vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false),
vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false),
}
filtered = user.FilterListDir(dirContents, "/dir2/vdir2/sub")
filtered = getFilteredInfo(dirContents, "/dir2/vdir2/sub")
require.Len(t, filtered, 1)
assert.Equal(t, "file1.jpg", filtered[0].Name())
@ -754,14 +764,14 @@ func TestFilePatterns(t *testing.T) {
vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false),
vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false),
}
filtered = user.FilterListDir(dirContents, "/dir4")
filtered = getFilteredInfo(dirContents, "/dir4")
require.Len(t, filtered, 0)
dirContents = []os.FileInfo{
vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false),
vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false),
}
filtered = user.FilterListDir(dirContents, "/dir4/vdir2/sub")
filtered = getFilteredInfo(dirContents, "/dir4/vdir2/sub")
require.Len(t, filtered, 0)
dirContents = []os.FileInfo{
@ -769,7 +779,7 @@ func TestFilePatterns(t *testing.T) {
vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false),
}
filtered = user.FilterListDir(dirContents, "/dir2")
filtered = getFilteredInfo(dirContents, "/dir2")
assert.Len(t, filtered, 2)
dirContents = []os.FileInfo{
@ -777,7 +787,7 @@ func TestFilePatterns(t *testing.T) {
vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false),
}
filtered = user.FilterListDir(dirContents, "/dir4")
filtered = getFilteredInfo(dirContents, "/dir4")
assert.Len(t, filtered, 0)
dirContents = []os.FileInfo{
@ -785,7 +795,7 @@ func TestFilePatterns(t *testing.T) {
vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false),
}
filtered = user.FilterListDir(dirContents, "/dir4/sub")
filtered = getFilteredInfo(dirContents, "/dir4/sub")
assert.Len(t, filtered, 0)
dirContents = []os.FileInfo{
@ -793,10 +803,10 @@ func TestFilePatterns(t *testing.T) {
vfs.NewFileInfo("vdir3.jpg", false, 123, time.Now(), false),
}
filtered = user.FilterListDir(dirContents, "/dir1")
filtered = getFilteredInfo(dirContents, "/dir1")
assert.Len(t, filtered, 5)
filtered = user.FilterListDir(dirContents, "/dir2")
filtered = getFilteredInfo(dirContents, "/dir2")
if assert.Len(t, filtered, 1) {
assert.True(t, filtered[0].IsDir())
}
@ -806,14 +816,14 @@ func TestFilePatterns(t *testing.T) {
vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false),
vfs.NewFileInfo("vdir3.jpg", false, 123, time.Now(), false),
}
filtered = user.FilterListDir(dirContents, "/dir1")
filtered = getFilteredInfo(dirContents, "/dir1")
assert.Len(t, filtered, 2)
dirContents = []os.FileInfo{
vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false),
vfs.NewFileInfo("vdir3.jpg", false, 123, time.Now(), false),
}
filtered = user.FilterListDir(dirContents, "/dir2")
filtered = getFilteredInfo(dirContents, "/dir2")
if assert.Len(t, filtered, 1) {
assert.False(t, filtered[0].IsDir())
}
@ -824,7 +834,7 @@ func TestFilePatterns(t *testing.T) {
vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false),
vfs.NewFileInfo("vdir3.jpg", false, 123, time.Now(), false),
}
filtered = user.FilterListDir(dirContents, "/dir2")
filtered = getFilteredInfo(dirContents, "/dir2")
if assert.Len(t, filtered, 2) {
assert.False(t, filtered[0].IsDir())
assert.False(t, filtered[1].IsDir())
@ -832,9 +842,9 @@ func TestFilePatterns(t *testing.T) {
user.VirtualFolders = virtualFolders
user.Filters = filters
filtered = user.FilterListDir(nil, "/dir1")
filtered = getFilteredInfo(nil, "/dir1")
assert.Len(t, filtered, 3)
filtered = user.FilterListDir(nil, "/dir2")
filtered = getFilteredInfo(nil, "/dir2")
assert.Len(t, filtered, 1)
dirContents = []os.FileInfo{
@ -843,7 +853,7 @@ func TestFilePatterns(t *testing.T) {
vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false),
vfs.NewFileInfo("vdir3.jpg", false, 456, time.Now(), false),
}
filtered = user.FilterListDir(dirContents, "/dir2")
filtered = getFilteredInfo(dirContents, "/dir2")
assert.Len(t, filtered, 2)
user = dataprovider.User{
@ -866,7 +876,7 @@ func TestFilePatterns(t *testing.T) {
vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false),
vfs.NewFileInfo("vdir3.jpg", false, 456, time.Now(), false),
}
filtered = user.FilterListDir(dirContents, "/dir3")
filtered = getFilteredInfo(dirContents, "/dir3")
assert.Len(t, filtered, 0)
dirContents = nil
@ -881,7 +891,7 @@ func TestFilePatterns(t *testing.T) {
dirContents = append(dirContents, vfs.NewFileInfo("ic35.*", false, 123, time.Now(), false))
dirContents = append(dirContents, vfs.NewFileInfo("file.jpg", false, 123, time.Now(), false))
filtered = user.FilterListDir(dirContents, "/dir3")
filtered = getFilteredInfo(dirContents, "/dir3")
require.Len(t, filtered, 1)
assert.Equal(t, "ic35", filtered[0].Name())
@ -890,7 +900,7 @@ func TestFilePatterns(t *testing.T) {
vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false),
vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false),
}
filtered = user.FilterListDir(dirContents, "/dir3/ic36")
filtered = getFilteredInfo(dirContents, "/dir3/ic36")
require.Len(t, filtered, 0)
dirContents = []os.FileInfo{
@ -898,7 +908,7 @@ func TestFilePatterns(t *testing.T) {
vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false),
vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false),
}
filtered = user.FilterListDir(dirContents, "/dir3/ic35")
filtered = getFilteredInfo(dirContents, "/dir3/ic35")
require.Len(t, filtered, 3)
dirContents = []os.FileInfo{
@ -906,7 +916,7 @@ func TestFilePatterns(t *testing.T) {
vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false),
vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false),
}
filtered = user.FilterListDir(dirContents, "/dir3/ic35/sub")
filtered = getFilteredInfo(dirContents, "/dir3/ic35/sub")
require.Len(t, filtered, 3)
res, _ = user.IsFileAllowed("/dir3/file.txt")
@ -930,7 +940,7 @@ func TestFilePatterns(t *testing.T) {
vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false),
vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false),
}
filtered = user.FilterListDir(dirContents, "/dir3/ic35/sub")
filtered = getFilteredInfo(dirContents, "/dir3/ic35/sub")
require.Len(t, filtered, 3)
user.Filters.FilePatterns = append(user.Filters.FilePatterns, sdk.PatternsFilter{
@ -949,7 +959,7 @@ func TestFilePatterns(t *testing.T) {
vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false),
vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false),
}
filtered = user.FilterListDir(dirContents, "/dir3/ic35/sub1")
filtered = getFilteredInfo(dirContents, "/dir3/ic35/sub1")
require.Len(t, filtered, 3)
dirContents = []os.FileInfo{
@ -957,7 +967,7 @@ func TestFilePatterns(t *testing.T) {
vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false),
vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false),
}
filtered = user.FilterListDir(dirContents, "/dir3/ic35/sub2")
filtered = getFilteredInfo(dirContents, "/dir3/ic35/sub2")
require.Len(t, filtered, 2)
dirContents = []os.FileInfo{
@ -965,7 +975,7 @@ func TestFilePatterns(t *testing.T) {
vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false),
vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false),
}
filtered = user.FilterListDir(dirContents, "/dir3/ic35/sub2/sub1")
filtered = getFilteredInfo(dirContents, "/dir3/ic35/sub2/sub1")
require.Len(t, filtered, 2)
res, _ = user.IsFileAllowed("/dir3/ic35/file.jpg")
@ -1023,7 +1033,7 @@ func TestFilePatterns(t *testing.T) {
vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false),
vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false),
}
filtered = user.FilterListDir(dirContents, "/dir3/ic35")
filtered = getFilteredInfo(dirContents, "/dir3/ic35")
require.Len(t, filtered, 1)
dirContents = []os.FileInfo{
@ -1031,6 +1041,116 @@ func TestFilePatterns(t *testing.T) {
vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false),
vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false),
}
filtered = user.FilterListDir(dirContents, "/dir3/ic35/abc")
filtered = getFilteredInfo(dirContents, "/dir3/ic35/abc")
require.Len(t, filtered, 1)
}
func TestListerAt(t *testing.T) {
dir := t.TempDir()
user := dataprovider.User{
BaseUser: sdk.BaseUser{
Username: "u",
Password: "p",
HomeDir: dir,
Status: 1,
Permissions: map[string][]string{
"/": {"*"},
},
},
}
conn := NewBaseConnection(xid.New().String(), ProtocolSFTP, "", "", user)
lister, err := conn.ListDir("/")
require.NoError(t, err)
files, err := lister.Next(1)
require.ErrorIs(t, err, io.EOF)
require.Len(t, files, 0)
err = lister.Close()
require.NoError(t, err)
conn.User.VirtualFolders = []vfs.VirtualFolder{
{
VirtualPath: "p1",
},
{
VirtualPath: "p2",
},
{
VirtualPath: "p3",
},
}
lister, err = conn.ListDir("/")
require.NoError(t, err)
files, err = lister.Next(2)
// virtual directories exceeds the limit
require.ErrorIs(t, err, io.EOF)
require.Len(t, files, 3)
files, err = lister.Next(2)
require.ErrorIs(t, err, io.EOF)
require.Len(t, files, 0)
_, err = lister.Next(-1)
require.ErrorContains(t, err, "invalid limit")
err = lister.Close()
require.NoError(t, err)
lister, err = conn.ListDir("/")
require.NoError(t, err)
_, err = lister.ListAt(nil, 0)
require.ErrorContains(t, err, "zero size")
err = lister.Close()
require.NoError(t, err)
for i := 0; i < 100; i++ {
f, err := os.Create(filepath.Join(dir, strconv.Itoa(i)))
require.NoError(t, err)
err = f.Close()
require.NoError(t, err)
}
lister, err = conn.ListDir("/")
require.NoError(t, err)
files = make([]os.FileInfo, 18)
n, err := lister.ListAt(files, 0)
require.NoError(t, err)
require.Equal(t, 18, n)
n, err = lister.ListAt(files, 0)
require.NoError(t, err)
require.Equal(t, 18, n)
files = make([]os.FileInfo, 100)
n, err = lister.ListAt(files, 0)
require.NoError(t, err)
require.Equal(t, 64+3, n)
n, err = lister.ListAt(files, 0)
require.ErrorIs(t, err, io.EOF)
require.Equal(t, 0, n)
n, err = lister.ListAt(files, 0)
require.ErrorIs(t, err, io.EOF)
require.Equal(t, 0, n)
err = lister.Close()
require.NoError(t, err)
n, err = lister.ListAt(files, 0)
require.Error(t, err)
require.NotErrorIs(t, err, io.EOF)
require.Equal(t, 0, n)
lister, err = conn.ListDir("/")
require.NoError(t, err)
lister.Add(vfs.NewFileInfo("..", true, 0, time.Unix(0, 0), false))
lister.Add(vfs.NewFileInfo(".", true, 0, time.Unix(0, 0), false))
files = make([]os.FileInfo, 1)
n, err = lister.ListAt(files, 0)
require.NoError(t, err)
require.Equal(t, 1, n)
assert.Equal(t, ".", files[0].Name())
files = make([]os.FileInfo, 2)
n, err = lister.ListAt(files, 0)
require.NoError(t, err)
require.Equal(t, 2, n)
assert.Equal(t, "..", files[0].Name())
assert.Equal(t, "p3", files[1].Name())
files = make([]os.FileInfo, 200)
n, err = lister.ListAt(files, 0)
require.NoError(t, err)
require.Equal(t, 102, n)
assert.Equal(t, "p2", files[0].Name())
assert.Equal(t, "p1", files[1].Name())
err = lister.Close()
require.NoError(t, err)
}

View file

@ -18,7 +18,9 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"os"
@ -37,6 +39,7 @@ import (
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/smtp"
"github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/internal/vfs"
)
// RetentionCheckNotification defines the supported notification methods for a retention check result
@ -226,8 +229,17 @@ func (c *RetentionCheck) removeFile(virtualPath string, info os.FileInfo) error
return c.conn.RemoveFile(fs, fsPath, virtualPath, info)
}
func (c *RetentionCheck) cleanupFolder(folderPath string) error {
deleteFilesPerms := []string{dataprovider.PermDelete, dataprovider.PermDeleteFiles}
func (c *RetentionCheck) hasCleanupPerms(folderPath string) bool {
if !c.conn.User.HasPerm(dataprovider.PermListItems, folderPath) {
return false
}
if !c.conn.User.HasAnyPerm([]string{dataprovider.PermDelete, dataprovider.PermDeleteFiles}, folderPath) {
return false
}
return true
}
func (c *RetentionCheck) cleanupFolder(folderPath string, recursion int) error {
startTime := time.Now()
result := folderRetentionCheckResult{
Path: folderPath,
@ -235,7 +247,15 @@ func (c *RetentionCheck) cleanupFolder(folderPath string) error {
defer func() {
c.results = append(c.results, result)
}()
if !c.conn.User.HasPerm(dataprovider.PermListItems, folderPath) || !c.conn.User.HasAnyPerm(deleteFilesPerms, folderPath) {
if recursion >= util.MaxRecursion {
result.Elapsed = time.Since(startTime)
result.Info = "data retention check skipped: recursion too deep"
c.conn.Log(logger.LevelError, "data retention check skipped, recursion too depth for %q: %d",
folderPath, recursion)
return util.ErrRecursionTooDeep
}
recursion++
if !c.hasCleanupPerms(folderPath) {
result.Elapsed = time.Since(startTime)
result.Info = "data retention check skipped: no permissions"
c.conn.Log(logger.LevelInfo, "user %q does not have permissions to check retention on %q, retention check skipped",
@ -259,7 +279,7 @@ func (c *RetentionCheck) cleanupFolder(folderPath string) error {
}
c.conn.Log(logger.LevelDebug, "start retention check for folder %q, retention: %v hours, delete empty dirs? %v, ignore user perms? %v",
folderPath, folderRetention.Retention, folderRetention.DeleteEmptyDirs, folderRetention.IgnoreUserPermissions)
files, err := c.conn.ListDir(folderPath)
lister, err := c.conn.ListDir(folderPath)
if err != nil {
result.Elapsed = time.Since(startTime)
if err == c.conn.GetNotExistError() {
@ -267,40 +287,54 @@ func (c *RetentionCheck) cleanupFolder(folderPath string) error {
c.conn.Log(logger.LevelDebug, "folder %q does not exist, retention check skipped", folderPath)
return nil
}
result.Error = fmt.Sprintf("unable to list directory %q", folderPath)
result.Error = fmt.Sprintf("unable to get lister for directory %q", folderPath)
c.conn.Log(logger.LevelError, result.Error)
return err
}
for _, info := range files {
virtualPath := path.Join(folderPath, info.Name())
if info.IsDir() {
if err := c.cleanupFolder(virtualPath); err != nil {
result.Elapsed = time.Since(startTime)
result.Error = fmt.Sprintf("unable to check folder: %v", err)
c.conn.Log(logger.LevelError, "unable to cleanup folder %q: %v", virtualPath, err)
return err
}
} else {
retentionTime := info.ModTime().Add(time.Duration(folderRetention.Retention) * time.Hour)
if retentionTime.Before(time.Now()) {
if err := c.removeFile(virtualPath, info); err != nil {
defer lister.Close()
for {
files, err := lister.Next(vfs.ListerBatchSize)
finished := errors.Is(err, io.EOF)
if err := lister.convertError(err); err != nil {
result.Elapsed = time.Since(startTime)
result.Error = fmt.Sprintf("unable to list directory %q", folderPath)
c.conn.Log(logger.LevelError, "unable to list dir %q: %v", folderPath, err)
return err
}
for _, info := range files {
virtualPath := path.Join(folderPath, info.Name())
if info.IsDir() {
if err := c.cleanupFolder(virtualPath, recursion); err != nil {
result.Elapsed = time.Since(startTime)
result.Error = fmt.Sprintf("unable to remove file %q: %v", virtualPath, err)
c.conn.Log(logger.LevelError, "unable to remove file %q, retention %v: %v",
virtualPath, retentionTime, err)
result.Error = fmt.Sprintf("unable to check folder: %v", err)
c.conn.Log(logger.LevelError, "unable to cleanup folder %q: %v", virtualPath, err)
return err
}
c.conn.Log(logger.LevelDebug, "removed file %q, modification time: %v, retention: %v hours, retention time: %v",
virtualPath, info.ModTime(), folderRetention.Retention, retentionTime)
result.DeletedFiles++
result.DeletedSize += info.Size()
} else {
retentionTime := info.ModTime().Add(time.Duration(folderRetention.Retention) * time.Hour)
if retentionTime.Before(time.Now()) {
if err := c.removeFile(virtualPath, info); err != nil {
result.Elapsed = time.Since(startTime)
result.Error = fmt.Sprintf("unable to remove file %q: %v", virtualPath, err)
c.conn.Log(logger.LevelError, "unable to remove file %q, retention %v: %v",
virtualPath, retentionTime, err)
return err
}
c.conn.Log(logger.LevelDebug, "removed file %q, modification time: %v, retention: %v hours, retention time: %v",
virtualPath, info.ModTime(), folderRetention.Retention, retentionTime)
result.DeletedFiles++
result.DeletedSize += info.Size()
}
}
}
if finished {
break
}
}
if folderRetention.DeleteEmptyDirs {
c.checkEmptyDirRemoval(folderPath)
}
lister.Close()
c.checkEmptyDirRemoval(folderPath, folderRetention.DeleteEmptyDirs)
result.Elapsed = time.Since(startTime)
c.conn.Log(logger.LevelDebug, "retention check completed for folder %q, deleted files: %v, deleted size: %v bytes",
folderPath, result.DeletedFiles, result.DeletedSize)
@ -308,8 +342,8 @@ func (c *RetentionCheck) cleanupFolder(folderPath string) error {
return nil
}
func (c *RetentionCheck) checkEmptyDirRemoval(folderPath string) {
if folderPath == "/" {
func (c *RetentionCheck) checkEmptyDirRemoval(folderPath string, checkVal bool) {
if folderPath == "/" || !checkVal {
return
}
for _, folder := range c.Folders {
@ -322,10 +356,14 @@ func (c *RetentionCheck) checkEmptyDirRemoval(folderPath string) {
dataprovider.PermDeleteDirs,
}, path.Dir(folderPath),
) {
files, err := c.conn.ListDir(folderPath)
if err == nil && len(files) == 0 {
err = c.conn.RemoveDir(folderPath)
c.conn.Log(logger.LevelDebug, "tried to remove empty dir %q, error: %v", folderPath, err)
lister, err := c.conn.ListDir(folderPath)
if err == nil {
files, err := lister.Next(1)
lister.Close()
if len(files) == 0 && errors.Is(err, io.EOF) {
err = c.conn.RemoveDir(folderPath)
c.conn.Log(logger.LevelDebug, "tried to remove empty dir %q, error: %v", folderPath, err)
}
}
}
}
@ -339,7 +377,7 @@ func (c *RetentionCheck) Start() error {
startTime := time.Now()
for _, folder := range c.Folders {
if folder.Retention > 0 {
if err := c.cleanupFolder(folder.Path); err != nil {
if err := c.cleanupFolder(folder.Path, 0); err != nil {
c.conn.Log(logger.LevelError, "retention check failed, unable to cleanup folder %q", folder.Path)
c.sendNotifications(time.Since(startTime), err)
return err

View file

@ -28,6 +28,7 @@ import (
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/smtp"
"github.com/drakkan/sftpgo/v2/internal/util"
)
func TestRetentionValidation(t *testing.T) {
@ -272,7 +273,9 @@ func TestRetentionPermissionsAndGetFolder(t *testing.T) {
conn.SetProtocol(ProtocolDataRetention)
conn.ID = fmt.Sprintf("data_retention_%v", user.Username)
check.conn = conn
assert.False(t, check.hasCleanupPerms(check.Folders[2].Path))
check.updateUserPermissions()
assert.True(t, check.hasCleanupPerms(check.Folders[2].Path))
assert.Equal(t, []string{dataprovider.PermListItems, dataprovider.PermDelete}, conn.User.Permissions["/"])
assert.Equal(t, []string{dataprovider.PermListItems}, conn.User.Permissions["/dir1"])
assert.Equal(t, []string{dataprovider.PermAny}, conn.User.Permissions["/dir2"])
@ -391,8 +394,11 @@ func TestCleanupErrors(t *testing.T) {
err := check.removeFile("missing file", nil)
assert.Error(t, err)
err = check.cleanupFolder("/")
err = check.cleanupFolder("/", 0)
assert.Error(t, err)
err = check.cleanupFolder("/", 1000)
assert.ErrorIs(t, err, util.ErrRecursionTooDeep)
assert.True(t, RetentionChecks.remove(user.Username))
}

View file

@ -988,11 +988,16 @@ func getFileWriter(conn *BaseConnection, virtualPath string, expectedSize int64)
return w, numFiles, truncatedSize, cancelFn, nil
}
func addZipEntry(wr *zipWriterWrapper, conn *BaseConnection, entryPath, baseDir string) error {
func addZipEntry(wr *zipWriterWrapper, conn *BaseConnection, entryPath, baseDir string, recursion int) error {
if entryPath == wr.Name {
// skip the archive itself
return nil
}
if recursion >= util.MaxRecursion {
eventManagerLog(logger.LevelError, "unable to add zip entry %q, recursion too deep: %v", entryPath, recursion)
return util.ErrRecursionTooDeep
}
recursion++
info, err := conn.DoStat(entryPath, 1, false)
if err != nil {
eventManagerLog(logger.LevelError, "unable to add zip entry %q, stat error: %v", entryPath, err)
@ -1018,25 +1023,42 @@ func addZipEntry(wr *zipWriterWrapper, conn *BaseConnection, entryPath, baseDir
eventManagerLog(logger.LevelError, "unable to create zip entry %q: %v", entryPath, err)
return fmt.Errorf("unable to create zip entry %q: %w", entryPath, err)
}
contents, err := conn.ListDir(entryPath)
lister, err := conn.ListDir(entryPath)
if err != nil {
eventManagerLog(logger.LevelError, "unable to add zip entry %q, read dir error: %v", entryPath, err)
eventManagerLog(logger.LevelError, "unable to add zip entry %q, get dir lister error: %v", entryPath, err)
return fmt.Errorf("unable to add zip entry %q: %w", entryPath, err)
}
for _, info := range contents {
fullPath := util.CleanPath(path.Join(entryPath, info.Name()))
if err := addZipEntry(wr, conn, fullPath, baseDir); err != nil {
eventManagerLog(logger.LevelError, "unable to add zip entry: %v", err)
return err
defer lister.Close()
for {
contents, err := lister.Next(vfs.ListerBatchSize)
finished := errors.Is(err, io.EOF)
if err := lister.convertError(err); err != nil {
eventManagerLog(logger.LevelError, "unable to add zip entry %q, read dir error: %v", entryPath, err)
return fmt.Errorf("unable to add zip entry %q: %w", entryPath, err)
}
for _, info := range contents {
fullPath := util.CleanPath(path.Join(entryPath, info.Name()))
if err := addZipEntry(wr, conn, fullPath, baseDir, recursion); err != nil {
eventManagerLog(logger.LevelError, "unable to add zip entry: %v", err)
return err
}
}
if finished {
return nil
}
}
return nil
}
if !info.Mode().IsRegular() {
// we only allow regular files
eventManagerLog(logger.LevelInfo, "skipping zip entry for non regular file %q", entryPath)
return nil
}
return addFileToZip(wr, conn, entryPath, entryName, info.ModTime())
}
func addFileToZip(wr *zipWriterWrapper, conn *BaseConnection, entryPath, entryName string, modTime time.Time) error {
reader, cancelFn, err := getFileReader(conn, entryPath)
if err != nil {
eventManagerLog(logger.LevelError, "unable to add zip entry %q, cannot open file: %v", entryPath, err)
@ -1048,7 +1070,7 @@ func addZipEntry(wr *zipWriterWrapper, conn *BaseConnection, entryPath, baseDir
f, err := wr.Writer.CreateHeader(&zip.FileHeader{
Name: entryName,
Method: zip.Deflate,
Modified: info.ModTime(),
Modified: modTime,
})
if err != nil {
eventManagerLog(logger.LevelError, "unable to create zip entry %q: %v", entryPath, err)
@ -1890,18 +1912,28 @@ func getArchiveBaseDir(paths []string) string {
func getSizeForPath(conn *BaseConnection, p string, info os.FileInfo) (int64, error) {
if info.IsDir() {
var dirSize int64
entries, err := conn.ListDir(p)
lister, err := conn.ListDir(p)
if err != nil {
return 0, err
}
for _, entry := range entries {
size, err := getSizeForPath(conn, path.Join(p, entry.Name()), entry)
if err != nil {
defer lister.Close()
for {
entries, err := lister.Next(vfs.ListerBatchSize)
finished := errors.Is(err, io.EOF)
if err != nil && !finished {
return 0, err
}
dirSize += size
for _, entry := range entries {
size, err := getSizeForPath(conn, path.Join(p, entry.Name()), entry)
if err != nil {
return 0, err
}
dirSize += size
}
if finished {
return dirSize, nil
}
}
return dirSize, nil
}
if info.Mode().IsRegular() {
return info.Size(), nil
@ -1978,7 +2010,7 @@ func executeCompressFsActionForUser(c dataprovider.EventActionFsCompress, replac
}
startTime := time.Now()
for _, item := range paths {
if err := addZipEntry(zipWriter, conn, item, baseDir); err != nil {
if err := addZipEntry(zipWriter, conn, item, baseDir, 0); err != nil {
closeWriterAndUpdateQuota(writer, conn, name, "", numFiles, truncatedSize, err, operationUpload, startTime) //nolint:errcheck
return err
}

View file

@ -1835,7 +1835,7 @@ func TestFilesystemActionErrors(t *testing.T) {
Writer: zip.NewWriter(bytes.NewBuffer(nil)),
Entries: map[string]bool{},
}
err = addZipEntry(wr, conn, "/adir/sub/f.dat", "/adir/sub/sub")
err = addZipEntry(wr, conn, "/adir/sub/f.dat", "/adir/sub/sub", 0)
assert.Error(t, err)
assert.Contains(t, getErrorString(err), "is outside base dir")
}

View file

@ -131,9 +131,9 @@ func getDefenderHostQuery() string {
sqlTableDefenderHosts, sqlPlaceholders[0], sqlPlaceholders[1])
}
func getDefenderEventsQuery(hostIDS []int64) string {
func getDefenderEventsQuery(hostIDs []int64) string {
var sb strings.Builder
for _, hID := range hostIDS {
for _, hID := range hostIDs {
if sb.Len() == 0 {
sb.WriteString("(")
} else {

View file

@ -676,8 +676,7 @@ func (u *User) GetVirtualFoldersInPath(virtualPath string) map[string]bool {
result := make(map[string]bool)
for idx := range u.VirtualFolders {
v := &u.VirtualFolders[idx]
dirsForPath := util.GetDirsForVirtualPath(v.VirtualPath)
dirsForPath := util.GetDirsForVirtualPath(u.VirtualFolders[idx].VirtualPath)
for index := range dirsForPath {
d := dirsForPath[index]
if d == "/" {
@ -716,13 +715,34 @@ func (u *User) hasVirtualDirs() bool {
return numFolders > 0
}
// FilterListDir adds virtual folders and remove hidden items from the given files list
// GetVirtualFoldersInfo returns []os.FileInfo for virtual folders
func (u *User) GetVirtualFoldersInfo(virtualPath string) []os.FileInfo {
filter := u.getPatternsFilterForPath(virtualPath)
if !u.hasVirtualDirs() && filter.DenyPolicy != sdk.DenyPolicyHide {
return nil
}
vdirs := u.GetVirtualFoldersInPath(virtualPath)
result := make([]os.FileInfo, 0, len(vdirs))
for dir := range u.GetVirtualFoldersInPath(virtualPath) {
dirName := path.Base(dir)
if filter.DenyPolicy == sdk.DenyPolicyHide {
if !filter.CheckAllowed(dirName) {
continue
}
}
result = append(result, vfs.NewFileInfo(dirName, true, 0, time.Unix(0, 0), false))
}
return result
}
// FilterListDir removes hidden items from the given files list
func (u *User) FilterListDir(dirContents []os.FileInfo, virtualPath string) []os.FileInfo {
filter := u.getPatternsFilterForPath(virtualPath)
if !u.hasVirtualDirs() && filter.DenyPolicy != sdk.DenyPolicyHide {
return dirContents
}
vdirs := make(map[string]bool)
for dir := range u.GetVirtualFoldersInPath(virtualPath) {
dirName := path.Base(dir)
@ -735,36 +755,24 @@ func (u *User) FilterListDir(dirContents []os.FileInfo, virtualPath string) []os
}
validIdx := 0
for index, fi := range dirContents {
for dir := range vdirs {
if fi.Name() == dir {
if !fi.IsDir() {
fi = vfs.NewFileInfo(dir, true, 0, time.Unix(0, 0), false)
dirContents[index] = fi
for idx := range dirContents {
fi := dirContents[idx]
if fi.Name() != "." && fi.Name() != ".." {
if _, ok := vdirs[fi.Name()]; ok {
continue
}
if filter.DenyPolicy == sdk.DenyPolicyHide {
if !filter.CheckAllowed(fi.Name()) {
continue
}
delete(vdirs, dir)
}
}
if filter.DenyPolicy == sdk.DenyPolicyHide {
if filter.CheckAllowed(fi.Name()) {
dirContents[validIdx] = fi
validIdx++
}
}
dirContents[validIdx] = fi
validIdx++
}
if filter.DenyPolicy == sdk.DenyPolicyHide {
for idx := validIdx; idx < len(dirContents); idx++ {
dirContents[idx] = nil
}
dirContents = dirContents[:validIdx]
}
for dir := range vdirs {
fi := vfs.NewFileInfo(dir, true, 0, time.Unix(0, 0), false)
dirContents = append(dirContents, fi)
}
return dirContents
return dirContents[:validIdx]
}
// IsMappedPath returns true if the specified filesystem path has a virtual folder mapping.

View file

@ -2544,14 +2544,14 @@ func TestRename(t *testing.T) {
assert.NoError(t, err)
err = client.MakeDir(path.Join(otherDir, testDir))
assert.NoError(t, err)
code, response, err := client.SendCustomCommand(fmt.Sprintf("SITE CHMOD 0001 %v", otherDir))
code, response, err := client.SendCommand(fmt.Sprintf("SITE CHMOD 0001 %v", otherDir))
assert.NoError(t, err)
assert.Equal(t, ftp.StatusCommandOK, code)
assert.Equal(t, "SITE CHMOD command successful", response)
err = client.Rename(testDir, path.Join(otherDir, testDir))
assert.Error(t, err)
code, response, err = client.SendCustomCommand(fmt.Sprintf("SITE CHMOD 755 %v", otherDir))
code, response, err = client.SendCommand(fmt.Sprintf("SITE CHMOD 755 %v", otherDir))
assert.NoError(t, err)
assert.Equal(t, ftp.StatusCommandOK, code)
assert.Equal(t, "SITE CHMOD command successful", response)
@ -2611,7 +2611,7 @@ func TestSymlink(t *testing.T) {
assert.NoError(t, err)
err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0)
assert.NoError(t, err)
code, _, err := client.SendCustomCommand(fmt.Sprintf("SITE SYMLINK %v %v", testFileName, testFileName+".link"))
code, _, err := client.SendCommand(fmt.Sprintf("SITE SYMLINK %v %v", testFileName, testFileName+".link"))
assert.NoError(t, err)
assert.Equal(t, ftp.StatusCommandOK, code)
@ -2622,15 +2622,15 @@ func TestSymlink(t *testing.T) {
assert.NoError(t, err)
err = client.MakeDir(path.Join(otherDir, testDir))
assert.NoError(t, err)
code, response, err := client.SendCustomCommand(fmt.Sprintf("SITE CHMOD 0001 %v", otherDir))
code, response, err := client.SendCommand(fmt.Sprintf("SITE CHMOD 0001 %v", otherDir))
assert.NoError(t, err)
assert.Equal(t, ftp.StatusCommandOK, code)
assert.Equal(t, "SITE CHMOD command successful", response)
code, _, err = client.SendCustomCommand(fmt.Sprintf("SITE SYMLINK %v %v", testDir, path.Join(otherDir, testDir)))
code, _, err = client.SendCommand(fmt.Sprintf("SITE SYMLINK %v %v", testDir, path.Join(otherDir, testDir)))
assert.NoError(t, err)
assert.Equal(t, ftp.StatusFileUnavailable, code)
code, response, err = client.SendCustomCommand(fmt.Sprintf("SITE CHMOD 755 %v", otherDir))
code, response, err = client.SendCommand(fmt.Sprintf("SITE CHMOD 755 %v", otherDir))
assert.NoError(t, err)
assert.Equal(t, ftp.StatusCommandOK, code)
assert.Equal(t, "SITE CHMOD command successful", response)
@ -2860,17 +2860,17 @@ func TestAllocateAvailable(t *testing.T) {
assert.NoError(t, err)
client, err := getFTPClient(user, false, nil)
if assert.NoError(t, err) {
code, response, err := client.SendCustomCommand("allo 2000000")
code, response, err := client.SendCommand("allo 2000000")
assert.NoError(t, err)
assert.Equal(t, ftp.StatusCommandOK, code)
assert.Equal(t, "Done !", response)
code, response, err = client.SendCustomCommand("AVBL /vdir")
code, response, err = client.SendCommand("AVBL /vdir")
assert.NoError(t, err)
assert.Equal(t, ftp.StatusFile, code)
assert.Equal(t, "110", response)
code, _, err = client.SendCustomCommand("AVBL")
code, _, err = client.SendCommand("AVBL")
assert.NoError(t, err)
assert.Equal(t, ftp.StatusFile, code)
@ -2886,7 +2886,7 @@ func TestAllocateAvailable(t *testing.T) {
testFileSize := user.QuotaSize - 1
err = createTestFile(testFilePath, testFileSize)
assert.NoError(t, err)
code, response, err := client.SendCustomCommand("allo 1000")
code, response, err := client.SendCommand("allo 1000")
assert.NoError(t, err)
assert.Equal(t, ftp.StatusCommandOK, code)
assert.Equal(t, "Done !", response)
@ -2894,7 +2894,7 @@ func TestAllocateAvailable(t *testing.T) {
err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0)
assert.NoError(t, err)
code, response, err = client.SendCustomCommand("AVBL")
code, response, err = client.SendCommand("AVBL")
assert.NoError(t, err)
assert.Equal(t, ftp.StatusFile, code)
assert.Equal(t, "1", response)
@ -2909,7 +2909,7 @@ func TestAllocateAvailable(t *testing.T) {
assert.NoError(t, err)
client, err = getFTPClient(user, false, nil)
if assert.NoError(t, err) {
code, response, err := client.SendCustomCommand("AVBL")
code, response, err := client.SendCommand("AVBL")
assert.NoError(t, err)
assert.Equal(t, ftp.StatusFile, code)
assert.Equal(t, "1", response)
@ -2925,7 +2925,7 @@ func TestAllocateAvailable(t *testing.T) {
assert.NoError(t, err)
client, err = getFTPClient(user, false, nil)
if assert.NoError(t, err) {
code, response, err := client.SendCustomCommand("AVBL")
code, response, err := client.SendCommand("AVBL")
assert.NoError(t, err)
assert.Equal(t, ftp.StatusFile, code)
assert.Equal(t, "5242880", response)
@ -2941,7 +2941,7 @@ func TestAllocateAvailable(t *testing.T) {
assert.NoError(t, err)
client, err = getFTPClient(user, false, nil)
if assert.NoError(t, err) {
code, response, err := client.SendCustomCommand("AVBL")
code, response, err := client.SendCommand("AVBL")
assert.NoError(t, err)
assert.Equal(t, ftp.StatusFile, code)
assert.Equal(t, "5242880", response)
@ -2958,12 +2958,12 @@ func TestAllocateAvailable(t *testing.T) {
assert.NoError(t, err)
client, err = getFTPClient(user, false, nil)
if assert.NoError(t, err) {
code, response, err := client.SendCustomCommand("allo 10000")
code, response, err := client.SendCommand("allo 10000")
assert.NoError(t, err)
assert.Equal(t, ftp.StatusCommandOK, code)
assert.Equal(t, "Done !", response)
code, response, err = client.SendCustomCommand("AVBL")
code, response, err = client.SendCommand("AVBL")
assert.NoError(t, err)
assert.Equal(t, ftp.StatusFile, code)
assert.Equal(t, "100", response)
@ -2977,7 +2977,7 @@ func TestAllocateAvailable(t *testing.T) {
assert.NoError(t, err)
client, err = getFTPClient(user, false, nil)
if assert.NoError(t, err) {
code, response, err := client.SendCustomCommand("AVBL")
code, response, err := client.SendCommand("AVBL")
assert.NoError(t, err)
assert.Equal(t, ftp.StatusFile, code)
assert.Equal(t, "0", response)
@ -2989,7 +2989,7 @@ func TestAllocateAvailable(t *testing.T) {
assert.NoError(t, err)
client, err = getFTPClient(user, false, nil)
if assert.NoError(t, err) {
code, response, err := client.SendCustomCommand("AVBL")
code, response, err := client.SendCommand("AVBL")
assert.NoError(t, err)
assert.Equal(t, ftp.StatusFile, code)
assert.Equal(t, "1", response)
@ -3013,7 +3013,7 @@ func TestAvailableSFTPFs(t *testing.T) {
assert.NoError(t, err)
client, err := getFTPClient(sftpUser, false, nil)
if assert.NoError(t, err) {
code, response, err := client.SendCustomCommand("AVBL /")
code, response, err := client.SendCommand("AVBL /")
assert.NoError(t, err)
assert.Equal(t, ftp.StatusFile, code)
avblSize, err := strconv.ParseInt(response, 10, 64)
@ -3051,7 +3051,7 @@ func TestChtimes(t *testing.T) {
assert.NoError(t, err)
mtime := time.Now().Format("20060102150405")
code, response, err := client.SendCustomCommand(fmt.Sprintf("MFMT %v %v", mtime, testFileName))
code, response, err := client.SendCommand(fmt.Sprintf("MFMT %v %v", mtime, testFileName))
assert.NoError(t, err)
assert.Equal(t, ftp.StatusFile, code)
assert.Equal(t, fmt.Sprintf("Modify=%v; %v", mtime, testFileName), response)
@ -3097,7 +3097,7 @@ func TestChown(t *testing.T) {
assert.NoError(t, err)
err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0)
assert.NoError(t, err)
code, response, err := client.SendCustomCommand(fmt.Sprintf("SITE CHOWN 1000:1000 %v", testFileName))
code, response, err := client.SendCommand(fmt.Sprintf("SITE CHOWN 1000:1000 %v", testFileName))
assert.NoError(t, err)
assert.Equal(t, ftp.StatusFileUnavailable, code)
assert.Equal(t, "Couldn't chown: operation unsupported", response)
@ -3135,7 +3135,7 @@ func TestChmod(t *testing.T) {
err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0)
assert.NoError(t, err)
code, response, err := client.SendCustomCommand(fmt.Sprintf("SITE CHMOD 600 %v", testFileName))
code, response, err := client.SendCommand(fmt.Sprintf("SITE CHMOD 600 %v", testFileName))
assert.NoError(t, err)
assert.Equal(t, ftp.StatusCommandOK, code)
assert.Equal(t, "SITE CHMOD command successful", response)
@ -3182,7 +3182,7 @@ func TestCombineDisabled(t *testing.T) {
err = checkBasicFTP(client)
assert.NoError(t, err)
code, response, err := client.SendCustomCommand("COMB file file.1 file.2")
code, response, err := client.SendCommand("COMB file file.1 file.2")
assert.NoError(t, err)
assert.Equal(t, ftp.StatusNotImplemented, code)
assert.Equal(t, "COMB support is disabled", response)
@ -3208,12 +3208,12 @@ func TestActiveModeDisabled(t *testing.T) {
if assert.NoError(t, err) {
err = checkBasicFTP(client)
assert.NoError(t, err)
code, response, err := client.SendCustomCommand("PORT 10,2,0,2,4,31")
code, response, err := client.SendCommand("PORT 10,2,0,2,4,31")
assert.NoError(t, err)
assert.Equal(t, ftp.StatusNotAvailable, code)
assert.Equal(t, "PORT command is disabled", response)
code, response, err = client.SendCustomCommand("EPRT |1|132.235.1.2|6275|")
code, response, err = client.SendCommand("EPRT |1|132.235.1.2|6275|")
assert.NoError(t, err)
assert.Equal(t, ftp.StatusNotAvailable, code)
assert.Equal(t, "EPRT command is disabled", response)
@ -3224,12 +3224,12 @@ func TestActiveModeDisabled(t *testing.T) {
client, err = getFTPClient(user, false, nil)
if assert.NoError(t, err) {
code, response, err := client.SendCustomCommand("PORT 10,2,0,2,4,31")
code, response, err := client.SendCommand("PORT 10,2,0,2,4,31")
assert.NoError(t, err)
assert.Equal(t, ftp.StatusBadArguments, code)
assert.Equal(t, "Your request does not meet the configured security requirements", response)
code, response, err = client.SendCustomCommand("EPRT |1|132.235.1.2|6275|")
code, response, err = client.SendCommand("EPRT |1|132.235.1.2|6275|")
assert.NoError(t, err)
assert.Equal(t, ftp.StatusBadArguments, code)
assert.Equal(t, "Your request does not meet the configured security requirements", response)
@ -3253,7 +3253,7 @@ func TestSITEDisabled(t *testing.T) {
err = checkBasicFTP(client)
assert.NoError(t, err)
code, response, err := client.SendCustomCommand("SITE CHMOD 600 afile.txt")
code, response, err := client.SendCommand("SITE CHMOD 600 afile.txt")
assert.NoError(t, err)
assert.Equal(t, ftp.StatusBadCommand, code)
assert.Equal(t, "SITE support is disabled", response)
@ -3298,12 +3298,12 @@ func TestHASH(t *testing.T) {
err = f.Close()
assert.NoError(t, err)
code, response, err := client.SendCustomCommand(fmt.Sprintf("XSHA256 %v", testFileName))
code, response, err := client.SendCommand(fmt.Sprintf("XSHA256 %v", testFileName))
assert.NoError(t, err)
assert.Equal(t, ftp.StatusRequestedFileActionOK, code)
assert.Contains(t, response, hash)
code, response, err = client.SendCustomCommand(fmt.Sprintf("HASH %v", testFileName))
code, response, err = client.SendCommand(fmt.Sprintf("HASH %v", testFileName))
assert.NoError(t, err)
assert.Equal(t, ftp.StatusFile, code)
assert.Contains(t, response, hash)
@ -3359,7 +3359,7 @@ func TestCombine(t *testing.T) {
err = ftpUploadFile(testFilePath, testFileName+".2", testFileSize, client, 0)
assert.NoError(t, err)
code, response, err := client.SendCustomCommand(fmt.Sprintf("COMB %v %v %v", testFileName, testFileName+".1", testFileName+".2"))
code, response, err := client.SendCommand(fmt.Sprintf("COMB %v %v %v", testFileName, testFileName+".1", testFileName+".2"))
assert.NoError(t, err)
if user.Username == defaultUsername {
assert.Equal(t, ftp.StatusRequestedFileActionOK, code)

View file

@ -291,7 +291,7 @@ func (c *Connection) Symlink(oldname, newname string) error {
}
// ReadDir implements ClientDriverExtensionFilelist
func (c *Connection) ReadDir(name string) ([]os.FileInfo, error) {
func (c *Connection) ReadDir(name string) (ftpserver.DirLister, error) {
c.UpdateLastActivity()
if c.doWildcardListDir {
@ -302,7 +302,17 @@ func (c *Connection) ReadDir(name string) ([]os.FileInfo, error) {
// - dir*/*.xml is not supported
name = path.Dir(name)
c.clientContext.SetListPath(name)
return c.getListDirWithWildcards(name, baseName)
lister, err := c.ListDir(name)
if err != nil {
return nil, err
}
return &patternDirLister{
DirLister: lister,
pattern: baseName,
lastCommand: c.clientContext.GetLastCommand(),
dirName: name,
connectionPath: c.clientContext.Path(),
}, nil
}
return c.ListDir(name)
@ -506,31 +516,6 @@ func (c *Connection) handleFTPUploadToExistingFile(fs vfs.Fs, flags int, resolve
return t, nil
}
func (c *Connection) getListDirWithWildcards(dirName, pattern string) ([]os.FileInfo, error) {
files, err := c.ListDir(dirName)
if err != nil {
return files, err
}
validIdx := 0
var relativeBase string
if c.clientContext.GetLastCommand() != "NLST" {
relativeBase = getPathRelativeTo(c.clientContext.Path(), dirName)
}
for _, fi := range files {
match, err := path.Match(pattern, fi.Name())
if err != nil {
return files, err
}
if match {
files[validIdx] = vfs.NewFileInfo(path.Join(relativeBase, fi.Name()), fi.IsDir(), fi.Size(),
fi.ModTime(), true)
validIdx++
}
}
return files[:validIdx], nil
}
func (c *Connection) isListDirWithWildcards(name string) bool {
if strings.ContainsAny(name, "*?[]^") {
lastCommand := c.clientContext.GetLastCommand()
@ -559,3 +544,40 @@ func getPathRelativeTo(base, target string) string {
base = path.Dir(path.Clean(base))
}
}
type patternDirLister struct {
vfs.DirLister
pattern string
lastCommand string
dirName string
connectionPath string
}
func (l *patternDirLister) Next(limit int) ([]os.FileInfo, error) {
for {
files, err := l.DirLister.Next(limit)
if len(files) == 0 {
return files, err
}
validIdx := 0
var relativeBase string
if l.lastCommand != "NLST" {
relativeBase = getPathRelativeTo(l.connectionPath, l.dirName)
}
for _, fi := range files {
match, errMatch := path.Match(l.pattern, fi.Name())
if errMatch != nil {
return nil, errMatch
}
if match {
files[validIdx] = vfs.NewFileInfo(path.Join(relativeBase, fi.Name()), fi.IsDir(), fi.Size(),
fi.ModTime(), true)
validIdx++
}
}
files = files[:validIdx]
if err != nil || len(files) > 0 {
return files, err
}
}
}

View file

@ -74,12 +74,12 @@ func readUserFolder(w http.ResponseWriter, r *http.Request) {
defer common.Connections.Remove(connection.GetID())
name := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
contents, err := connection.ReadDir(name)
lister, err := connection.ReadDir(name)
if err != nil {
sendAPIResponse(w, r, err, "Unable to get directory contents", getMappedStatusCode(err))
sendAPIResponse(w, r, err, "Unable to get directory lister", getMappedStatusCode(err))
return
}
renderAPIDirContents(w, r, contents, false)
renderAPIDirContents(w, lister, false)
}
func createUserDir(w http.ResponseWriter, r *http.Request) {

View file

@ -213,12 +213,12 @@ func (s *httpdServer) readBrowsableShareContents(w http.ResponseWriter, r *http.
}
defer common.Connections.Remove(connection.GetID())
contents, err := connection.ReadDir(name)
lister, err := connection.ReadDir(name)
if err != nil {
sendAPIResponse(w, r, err, "Unable to get directory contents", getMappedStatusCode(err))
sendAPIResponse(w, r, err, "Unable to get directory lister", getMappedStatusCode(err))
return
}
renderAPIDirContents(w, r, contents, true)
renderAPIDirContents(w, lister, true)
}
func (s *httpdServer) downloadBrowsableSharedFile(w http.ResponseWriter, r *http.Request) {

View file

@ -17,6 +17,7 @@ package httpd
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
@ -44,6 +45,7 @@ import (
"github.com/drakkan/sftpgo/v2/internal/plugin"
"github.com/drakkan/sftpgo/v2/internal/smtp"
"github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/internal/vfs"
)
type pwdChange struct {
@ -280,23 +282,40 @@ func getSearchFilters(w http.ResponseWriter, r *http.Request) (int, int, string,
return limit, offset, order, err
}
func renderAPIDirContents(w http.ResponseWriter, r *http.Request, contents []os.FileInfo, omitNonRegularFiles bool) {
results := make([]map[string]any, 0, len(contents))
for _, info := range contents {
if omitNonRegularFiles && !info.Mode().IsDir() && !info.Mode().IsRegular() {
continue
func renderAPIDirContents(w http.ResponseWriter, lister vfs.DirLister, omitNonRegularFiles bool) {
defer lister.Close()
dataGetter := func(limit, _ int) ([]byte, int, error) {
contents, err := lister.Next(limit)
if errors.Is(err, io.EOF) {
err = nil
}
res := make(map[string]any)
res["name"] = info.Name()
if info.Mode().IsRegular() {
res["size"] = info.Size()
if err != nil {
return nil, 0, err
}
res["mode"] = info.Mode()
res["last_modified"] = info.ModTime().UTC().Format(time.RFC3339)
results = append(results, res)
results := make([]map[string]any, 0, len(contents))
for _, info := range contents {
if omitNonRegularFiles && !info.Mode().IsDir() && !info.Mode().IsRegular() {
continue
}
res := make(map[string]any)
res["name"] = info.Name()
if info.Mode().IsRegular() {
res["size"] = info.Size()
}
res["mode"] = info.Mode()
res["last_modified"] = info.ModTime().UTC().Format(time.RFC3339)
results = append(results, res)
}
data, err := json.Marshal(results)
count := limit
if len(results) == 0 {
count = 0
}
return data, count, err
}
render.JSON(w, r, results)
streamJSONArray(w, defaultQueryLimit, dataGetter)
}
func streamData(w io.Writer, data []byte) {
@ -355,7 +374,7 @@ func renderCompressedFiles(w http.ResponseWriter, conn *Connection, baseDir stri
for _, file := range files {
fullPath := util.CleanPath(path.Join(baseDir, file))
if err := addZipEntry(wr, conn, fullPath, baseDir); err != nil {
if err := addZipEntry(wr, conn, fullPath, baseDir, 0); err != nil {
if share != nil {
dataprovider.UpdateShareLastUse(share, -1) //nolint:errcheck
}
@ -371,7 +390,12 @@ func renderCompressedFiles(w http.ResponseWriter, conn *Connection, baseDir stri
}
}
func addZipEntry(wr *zip.Writer, conn *Connection, entryPath, baseDir string) error {
func addZipEntry(wr *zip.Writer, conn *Connection, entryPath, baseDir string, recursion int) error {
if recursion >= util.MaxRecursion {
conn.Log(logger.LevelDebug, "unable to add zip entry %q, recursion too depth: %d", entryPath, recursion)
return util.ErrRecursionTooDeep
}
recursion++
info, err := conn.Stat(entryPath, 1)
if err != nil {
conn.Log(logger.LevelDebug, "unable to add zip entry %q, stat error: %v", entryPath, err)
@ -392,24 +416,39 @@ func addZipEntry(wr *zip.Writer, conn *Connection, entryPath, baseDir string) er
conn.Log(logger.LevelError, "unable to create zip entry %q: %v", entryPath, err)
return err
}
contents, err := conn.ReadDir(entryPath)
lister, err := conn.ReadDir(entryPath)
if err != nil {
conn.Log(logger.LevelDebug, "unable to add zip entry %q, read dir error: %v", entryPath, err)
conn.Log(logger.LevelDebug, "unable to add zip entry %q, get list dir error: %v", entryPath, err)
return err
}
for _, info := range contents {
fullPath := util.CleanPath(path.Join(entryPath, info.Name()))
if err := addZipEntry(wr, conn, fullPath, baseDir); err != nil {
defer lister.Close()
for {
contents, err := lister.Next(vfs.ListerBatchSize)
finished := errors.Is(err, io.EOF)
if err != nil && !finished {
return err
}
for _, info := range contents {
fullPath := util.CleanPath(path.Join(entryPath, info.Name()))
if err := addZipEntry(wr, conn, fullPath, baseDir, recursion); err != nil {
return err
}
}
if finished {
return nil
}
}
return nil
}
if !info.Mode().IsRegular() {
// we only allow regular files
conn.Log(logger.LevelInfo, "skipping zip entry for non regular file %q", entryPath)
return nil
}
return addFileToZipEntry(wr, conn, entryPath, entryName, info)
}
func addFileToZipEntry(wr *zip.Writer, conn *Connection, entryPath, entryName string, info os.FileInfo) error {
reader, err := conn.getFileReader(entryPath, 0, http.MethodGet)
if err != nil {
conn.Log(logger.LevelDebug, "unable to add zip entry %q, cannot open file: %v", entryPath, err)

View file

@ -88,7 +88,7 @@ func (c *Connection) Stat(name string, mode int) (os.FileInfo, error) {
}
// ReadDir returns a list of directory entries
func (c *Connection) ReadDir(name string) ([]os.FileInfo, error) {
func (c *Connection) ReadDir(name string) (vfs.DirLister, error) {
c.UpdateLastActivity()
return c.ListDir(name)

View file

@ -16028,7 +16028,7 @@ func TestWebGetFiles(t *testing.T) {
setBearerForReq(req, webAPIToken)
rr = executeRequest(req)
checkResponseCode(t, http.StatusNotFound, rr)
assert.Contains(t, rr.Body.String(), "Unable to get directory contents")
assert.Contains(t, rr.Body.String(), "Unable to get directory lister")
req, _ = http.NewRequest(http.MethodGet, webClientFilesPath+"?path="+testFileName, nil)
setJWTCookieForReq(req, webToken)

View file

@ -2196,7 +2196,7 @@ func TestRecoverer(t *testing.T) {
}
func TestStreamJSONArray(t *testing.T) {
dataGetter := func(limit, offset int) ([]byte, int, error) {
dataGetter := func(_, _ int) ([]byte, int, error) {
return nil, 0, nil
}
rr := httptest.NewRecorder()
@ -2268,12 +2268,14 @@ func TestZipErrors(t *testing.T) {
assert.Contains(t, err.Error(), "write error")
}
err = addZipEntry(wr, connection, "/"+filepath.Base(testDir), "/")
err = addZipEntry(wr, connection, "/"+filepath.Base(testDir), "/", 0)
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "write error")
}
err = addZipEntry(wr, connection, "/"+filepath.Base(testDir), "/", 2000)
assert.ErrorIs(t, err, util.ErrRecursionTooDeep)
err = addZipEntry(wr, connection, "/"+filepath.Base(testDir), path.Join("/", filepath.Base(testDir), "dir"))
err = addZipEntry(wr, connection, "/"+filepath.Base(testDir), path.Join("/", filepath.Base(testDir), "dir"), 0)
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "is outside base dir")
}
@ -2282,14 +2284,14 @@ func TestZipErrors(t *testing.T) {
err = os.WriteFile(testFilePath, util.GenerateRandomBytes(65535), os.ModePerm)
assert.NoError(t, err)
err = addZipEntry(wr, connection, path.Join("/", filepath.Base(testDir), filepath.Base(testFilePath)),
"/"+filepath.Base(testDir))
"/"+filepath.Base(testDir), 0)
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "write error")
}
connection.User.Permissions["/"] = []string{dataprovider.PermListItems}
err = addZipEntry(wr, connection, path.Join("/", filepath.Base(testDir), filepath.Base(testFilePath)),
"/"+filepath.Base(testDir))
"/"+filepath.Base(testDir), 0)
assert.ErrorIs(t, err, os.ErrPermission)
// creating a virtual folder to a missing path stat is ok but readdir fails
@ -2301,14 +2303,14 @@ func TestZipErrors(t *testing.T) {
})
connection.User = user
wr = zip.NewWriter(bytes.NewBuffer(make([]byte, 0)))
err = addZipEntry(wr, connection, user.VirtualFolders[0].VirtualPath, "/")
err = addZipEntry(wr, connection, user.VirtualFolders[0].VirtualPath, "/", 0)
assert.Error(t, err)
user.Filters.FilePatterns = append(user.Filters.FilePatterns, sdk.PatternsFilter{
Path: "/",
DeniedPatterns: []string{"*.zip"},
})
err = addZipEntry(wr, connection, "/"+filepath.Base(testDir), "/")
err = addZipEntry(wr, connection, "/"+filepath.Base(testDir), "/", 0)
assert.ErrorIs(t, err, os.ErrPermission)
err = os.RemoveAll(testDir)

View file

@ -958,33 +958,50 @@ func (s *httpdServer) handleShareGetDirContents(w http.ResponseWriter, r *http.R
}
defer common.Connections.Remove(connection.GetID())
contents, err := connection.ReadDir(name)
lister, err := connection.ReadDir(name)
if err != nil {
sendAPIResponse(w, r, err, getI18NErrorString(err, util.I18nErrorDirListGeneric), getMappedStatusCode(err))
return
}
results := make([]map[string]any, 0, len(contents))
for _, info := range contents {
if !info.Mode().IsDir() && !info.Mode().IsRegular() {
continue
defer lister.Close()
dataGetter := func(limit, _ int) ([]byte, int, error) {
contents, err := lister.Next(limit)
if errors.Is(err, io.EOF) {
err = nil
}
res := make(map[string]any)
if info.IsDir() {
res["type"] = "1"
res["size"] = ""
} else {
res["type"] = "2"
res["size"] = info.Size()
if err != nil {
return nil, 0, err
}
res["meta"] = fmt.Sprintf("%v_%v", res["type"], info.Name())
res["name"] = info.Name()
res["url"] = getFileObjectURL(share.GetRelativePath(name), info.Name(),
path.Join(webClientPubSharesPath, share.ShareID, "browse"))
res["last_modified"] = getFileObjectModTime(info.ModTime())
results = append(results, res)
results := make([]map[string]any, 0, len(contents))
for _, info := range contents {
if !info.Mode().IsDir() && !info.Mode().IsRegular() {
continue
}
res := make(map[string]any)
if info.IsDir() {
res["type"] = "1"
res["size"] = ""
} else {
res["type"] = "2"
res["size"] = info.Size()
}
res["meta"] = fmt.Sprintf("%v_%v", res["type"], info.Name())
res["name"] = info.Name()
res["url"] = getFileObjectURL(share.GetRelativePath(name), info.Name(),
path.Join(webClientPubSharesPath, share.ShareID, "browse"))
res["last_modified"] = getFileObjectModTime(info.ModTime())
results = append(results, res)
}
data, err := json.Marshal(results)
count := limit
if len(results) == 0 {
count = 0
}
return data, count, err
}
render.JSON(w, r, results)
streamJSONArray(w, defaultQueryLimit, dataGetter)
}
func (s *httpdServer) handleClientUploadToShare(w http.ResponseWriter, r *http.Request) {
@ -1146,43 +1163,59 @@ func (s *httpdServer) handleClientGetDirContents(w http.ResponseWriter, r *http.
defer common.Connections.Remove(connection.GetID())
name := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
contents, err := connection.ReadDir(name)
lister, err := connection.ReadDir(name)
if err != nil {
statusCode := getMappedStatusCode(err)
sendAPIResponse(w, r, err, i18nListDirMsg(statusCode), statusCode)
return
}
defer lister.Close()
dirTree := r.URL.Query().Get("dirtree") == "1"
results := make([]map[string]any, 0, len(contents))
for _, info := range contents {
res := make(map[string]any)
res["url"] = getFileObjectURL(name, info.Name(), webClientFilesPath)
if info.IsDir() {
res["type"] = "1"
res["size"] = ""
res["dir_path"] = url.QueryEscape(path.Join(name, info.Name()))
} else {
if dirTree {
continue
}
res["type"] = "2"
if info.Mode()&os.ModeSymlink != 0 {
dataGetter := func(limit, _ int) ([]byte, int, error) {
contents, err := lister.Next(limit)
if errors.Is(err, io.EOF) {
err = nil
}
if err != nil {
return nil, 0, err
}
results := make([]map[string]any, 0, len(contents))
for _, info := range contents {
res := make(map[string]any)
res["url"] = getFileObjectURL(name, info.Name(), webClientFilesPath)
if info.IsDir() {
res["type"] = "1"
res["size"] = ""
res["dir_path"] = url.QueryEscape(path.Join(name, info.Name()))
} else {
res["size"] = info.Size()
if info.Size() < httpdMaxEditFileSize {
res["edit_url"] = strings.Replace(res["url"].(string), webClientFilesPath, webClientEditFilePath, 1)
if dirTree {
continue
}
res["type"] = "2"
if info.Mode()&os.ModeSymlink != 0 {
res["size"] = ""
} else {
res["size"] = info.Size()
if info.Size() < httpdMaxEditFileSize {
res["edit_url"] = strings.Replace(res["url"].(string), webClientFilesPath, webClientEditFilePath, 1)
}
}
}
res["meta"] = fmt.Sprintf("%v_%v", res["type"], info.Name())
res["name"] = info.Name()
res["last_modified"] = getFileObjectModTime(info.ModTime())
results = append(results, res)
}
res["meta"] = fmt.Sprintf("%v_%v", res["type"], info.Name())
res["name"] = info.Name()
res["last_modified"] = getFileObjectModTime(info.ModTime())
results = append(results, res)
data, err := json.Marshal(results)
count := limit
if len(results) == 0 {
count = 0
}
return data, count, err
}
render.JSON(w, r, results)
streamJSONArray(w, defaultQueryLimit, dataGetter)
}
func (s *httpdServer) handleClientGetFiles(w http.ResponseWriter, r *http.Request) {
@ -1917,27 +1950,45 @@ func doCheckExist(w http.ResponseWriter, r *http.Request, connection *Connection
return
}
contents, err := connection.ListDir(name)
lister, err := connection.ListDir(name)
if err != nil {
sendAPIResponse(w, r, err, "Unable to get directory contents", getMappedStatusCode(err))
return
}
existing := make([]map[string]any, 0)
for _, info := range contents {
if util.Contains(filesList.Files, info.Name()) {
res := make(map[string]any)
res["name"] = info.Name()
if info.IsDir() {
res["type"] = "1"
res["size"] = ""
} else {
res["type"] = "2"
res["size"] = info.Size()
}
existing = append(existing, res)
defer lister.Close()
dataGetter := func(limit, _ int) ([]byte, int, error) {
contents, err := lister.Next(limit)
if errors.Is(err, io.EOF) {
err = nil
}
if err != nil {
return nil, 0, err
}
existing := make([]map[string]any, 0)
for _, info := range contents {
if util.Contains(filesList.Files, info.Name()) {
res := make(map[string]any)
res["name"] = info.Name()
if info.IsDir() {
res["type"] = "1"
res["size"] = ""
} else {
res["type"] = "2"
res["size"] = info.Size()
}
existing = append(existing, res)
}
}
data, err := json.Marshal(existing)
count := limit
if len(existing) == 0 {
count = 0
}
return data, count, err
}
render.JSON(w, r, existing)
streamJSONArray(w, defaultQueryLimit, dataGetter)
}
func checkShareRedirectURL(next, base string) (bool, string) {

View file

@ -91,8 +91,8 @@ func (s *Service) initLogger() {
// Start initializes and starts the service
func (s *Service) Start(disableAWSInstallationCode bool) error {
s.initLogger()
logger.Info(logSender, "", "starting SFTPGo %v, config dir: %v, config file: %v, log max size: %v log max backups: %v "+
"log max age: %v log level: %v, log compress: %v, log utc time: %v, load data from: %q, grace time: %d secs",
logger.Info(logSender, "", "starting SFTPGo %s, config dir: %s, config file: %s, log max size: %d log max backups: %d "+
"log max age: %d log level: %s, log compress: %t, log utc time: %t, load data from: %q, grace time: %d secs",
version.GetAsString(), s.ConfigDir, s.ConfigFile, s.LogMaxSize, s.LogMaxBackups, s.LogMaxAge, s.LogLevel,
s.LogCompress, s.LogUTCTime, s.LoadDataFrom, graceTime)
// in portable mode we don't read configuration from file

View file

@ -216,16 +216,16 @@ func (c *Connection) Filelist(request *sftp.Request) (sftp.ListerAt, error) {
switch request.Method {
case "List":
files, err := c.ListDir(request.Filepath)
lister, err := c.ListDir(request.Filepath)
if err != nil {
return nil, err
}
modTime := time.Unix(0, 0)
if request.Filepath != "/" || c.folderPrefix != "" {
files = util.PrependFileInfo(files, vfs.NewFileInfo("..", true, 0, modTime, false))
lister.Add(vfs.NewFileInfo("..", true, 0, modTime, false))
}
files = util.PrependFileInfo(files, vfs.NewFileInfo(".", true, 0, modTime, false))
return listerAt(files), nil
lister.Add(vfs.NewFileInfo(".", true, 0, modTime, false))
return lister, nil
case "Stat":
if !c.User.HasPerm(dataprovider.PermListItems, path.Dir(request.Filepath)) {
return nil, sftp.ErrSSHFxPermissionDenied

View file

@ -1294,6 +1294,17 @@ func TestSCPProtocolMessages(t *testing.T) {
if assert.Error(t, err) {
assert.Equal(t, protocolErrorMsg, err.Error())
}
mockSSHChannel = MockChannel{
Buffer: bytes.NewBuffer(respBuffer),
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
ReadError: nil,
WriteError: writeErr,
}
scpCommand.connection.channel = &mockSSHChannel
err = scpCommand.downloadDirs(nil, nil)
assert.ErrorIs(t, err, writeErr)
}
func TestSCPTestDownloadProtocolMessages(t *testing.T) {

View file

@ -384,47 +384,65 @@ func (c *scpCommand) handleRecursiveDownload(fs vfs.Fs, dirPath, virtualPath str
if err != nil {
return err
}
files, err := fs.ReadDir(dirPath)
// dirPath is a fs path, not a virtual path
lister, err := fs.ReadDir(dirPath)
if err != nil {
c.sendErrorMessage(fs, err)
return err
}
files = c.connection.User.FilterListDir(files, fs.GetRelativePath(dirPath))
defer lister.Close()
vdirs := c.connection.User.GetVirtualFoldersInfo(virtualPath)
var dirs []string
for _, file := range files {
filePath := fs.GetRelativePath(fs.Join(dirPath, file.Name()))
if file.Mode().IsRegular() || file.Mode()&os.ModeSymlink != 0 {
err = c.handleDownload(filePath)
if err != nil {
break
}
} else if file.IsDir() {
dirs = append(dirs, filePath)
for {
files, err := lister.Next(vfs.ListerBatchSize)
finished := errors.Is(err, io.EOF)
if err != nil && !finished {
c.sendErrorMessage(fs, err)
return err
}
}
if err != nil {
c.sendErrorMessage(fs, err)
return err
}
for _, dir := range dirs {
err = c.handleDownload(dir)
if err != nil {
files = c.connection.User.FilterListDir(files, fs.GetRelativePath(dirPath))
if len(vdirs) > 0 {
files = append(files, vdirs...)
vdirs = nil
}
for _, file := range files {
filePath := fs.GetRelativePath(fs.Join(dirPath, file.Name()))
if file.Mode().IsRegular() || file.Mode()&os.ModeSymlink != 0 {
err = c.handleDownload(filePath)
if err != nil {
c.sendErrorMessage(fs, err)
return err
}
} else if file.IsDir() {
dirs = append(dirs, filePath)
}
}
if finished {
break
}
}
if err != nil {
lister.Close()
return c.downloadDirs(fs, dirs)
}
err = errors.New("unable to send directory for non recursive copy")
c.sendErrorMessage(nil, err)
return err
}
func (c *scpCommand) downloadDirs(fs vfs.Fs, dirs []string) error {
for _, dir := range dirs {
if err := c.handleDownload(dir); err != nil {
c.sendErrorMessage(fs, err)
return err
}
err = c.sendProtocolMessage("E\n")
if err != nil {
return err
}
return c.readConfirmationMessage()
}
err = fmt.Errorf("unable to send directory for non recursive copy")
c.sendErrorMessage(nil, err)
return err
if err := c.sendProtocolMessage("E\n"); err != nil {
return err
}
return c.readConfirmationMessage()
}
func (c *scpCommand) sendDownloadFileData(fs vfs.Fs, filePath string, stat os.FileInfo, transfer *transfer) error {

View file

@ -376,6 +376,7 @@ func (c *Configuration) Initialize(configDir string) error {
c.loadModuli(configDir)
sftp.SetSFTPExtensions(sftpExtensions...) //nolint:errcheck // we configure valid SFTP Extensions so we cannot get an error
sftp.MaxFilelist = vfs.ListerBatchSize
if err := c.configureSecurityOptions(serverConfig); err != nil {
return err

View file

@ -15,6 +15,7 @@
package util
import (
"errors"
"fmt"
)
@ -24,12 +25,16 @@ const (
"sftpgo serve -c \"<path to dir containing the default config file and templates directory>\""
)
// MaxRecursion defines the maximum number of allowed recursions
const MaxRecursion = 1000
// errors definitions
var (
ErrValidation = NewValidationError("")
ErrNotFound = NewRecordNotFoundError("")
ErrMethodDisabled = NewMethodDisabledError("")
ErrGeneric = NewGenericError("")
ErrValidation = NewValidationError("")
ErrNotFound = NewRecordNotFoundError("")
ErrMethodDisabled = NewMethodDisabledError("")
ErrGeneric = NewGenericError("")
ErrRecursionTooDeep = errors.New("recursion too deep")
)
// ValidationError raised if input data is not valid

View file

@ -56,6 +56,10 @@ const (
azFolderKey = "hdi_isfolder"
)
var (
azureBlobDefaultPageSize = int32(5000)
)
// AzureBlobFs is a Fs implementation for Azure Blob storage.
type AzureBlobFs struct {
connectionID string
@ -308,7 +312,7 @@ func (fs *AzureBlobFs) Rename(source, target string) (int, int64, error) {
if err != nil {
return -1, -1, err
}
return fs.renameInternal(source, target, fi)
return fs.renameInternal(source, target, fi, 0)
}
// Remove removes the named file or (empty) directory.
@ -408,76 +412,23 @@ func (*AzureBlobFs) Truncate(_ string, _ int64) error {
// ReadDir reads the directory named by dirname and returns
// a list of directory entries.
func (fs *AzureBlobFs) ReadDir(dirname string) ([]os.FileInfo, error) {
var result []os.FileInfo
func (fs *AzureBlobFs) ReadDir(dirname string) (DirLister, error) {
// dirname must be already cleaned
prefix := fs.getPrefix(dirname)
modTimes, err := getFolderModTimes(fs.getStorageID(), dirname)
if err != nil {
return result, err
}
prefixes := make(map[string]bool)
pager := fs.containerClient.NewListBlobsHierarchyPager("/", &container.ListBlobsHierarchyOptions{
Include: container.ListBlobsInclude{
//Metadata: true,
},
Prefix: &prefix,
Prefix: &prefix,
MaxResults: &azureBlobDefaultPageSize,
})
for pager.More() {
ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout))
defer cancelFn()
resp, err := pager.NextPage(ctx)
if err != nil {
metric.AZListObjectsCompleted(err)
return result, err
}
for _, blobPrefix := range resp.ListBlobsHierarchySegmentResponse.Segment.BlobPrefixes {
name := util.GetStringFromPointer(blobPrefix.Name)
// we don't support prefixes == "/" this will be sent if a key starts with "/"
if name == "" || name == "/" {
continue
}
// sometime we have duplicate prefixes, maybe an Azurite bug
name = strings.TrimPrefix(name, prefix)
if _, ok := prefixes[strings.TrimSuffix(name, "/")]; ok {
continue
}
result = append(result, NewFileInfo(name, true, 0, time.Unix(0, 0), false))
prefixes[strings.TrimSuffix(name, "/")] = true
}
for _, blobItem := range resp.ListBlobsHierarchySegmentResponse.Segment.BlobItems {
name := util.GetStringFromPointer(blobItem.Name)
name = strings.TrimPrefix(name, prefix)
size := int64(0)
isDir := false
modTime := time.Unix(0, 0)
if blobItem.Properties != nil {
size = util.GetIntFromPointer(blobItem.Properties.ContentLength)
modTime = util.GetTimeFromPointer(blobItem.Properties.LastModified)
contentType := util.GetStringFromPointer(blobItem.Properties.ContentType)
isDir = checkDirectoryMarkers(contentType, blobItem.Metadata)
if isDir {
// check if the dir is already included, it will be sent as blob prefix if it contains at least one item
if _, ok := prefixes[name]; ok {
continue
}
prefixes[name] = true
}
}
if t, ok := modTimes[name]; ok {
modTime = util.GetTimeFromMsecSinceEpoch(t)
}
result = append(result, NewFileInfo(name, isDir, size, modTime, false))
}
}
metric.AZListObjectsCompleted(nil)
return result, nil
return &azureBlobDirLister{
paginator: pager,
timeout: fs.ctxTimeout,
prefix: prefix,
prefixes: make(map[string]bool),
}, nil
}
// IsUploadResumeSupported returns true if resuming uploads is supported.
@ -569,7 +520,8 @@ func (fs *AzureBlobFs) getFileNamesInPrefix(fsPrefix string) (map[string]bool, e
Include: container.ListBlobsInclude{
//Metadata: true,
},
Prefix: &prefix,
Prefix: &prefix,
MaxResults: &azureBlobDefaultPageSize,
})
for pager.More() {
@ -615,7 +567,8 @@ func (fs *AzureBlobFs) GetDirSize(dirname string) (int, int64, error) {
Include: container.ListBlobsInclude{
Metadata: true,
},
Prefix: &prefix,
Prefix: &prefix,
MaxResults: &azureBlobDefaultPageSize,
})
for pager.More() {
@ -684,7 +637,8 @@ func (fs *AzureBlobFs) Walk(root string, walkFn filepath.WalkFunc) error {
Include: container.ListBlobsInclude{
Metadata: true,
},
Prefix: &prefix,
Prefix: &prefix,
MaxResults: &azureBlobDefaultPageSize,
})
for pager.More() {
@ -863,7 +817,7 @@ func (fs *AzureBlobFs) copyFileInternal(source, target string) error {
return nil
}
func (fs *AzureBlobFs) renameInternal(source, target string, fi os.FileInfo) (int, int64, error) {
func (fs *AzureBlobFs) renameInternal(source, target string, fi os.FileInfo, recursion int) (int, int64, error) {
var numFiles int
var filesSize int64
@ -881,24 +835,12 @@ func (fs *AzureBlobFs) renameInternal(source, target string, fi os.FileInfo) (in
return numFiles, filesSize, err
}
if renameMode == 1 {
entries, err := fs.ReadDir(source)
files, size, err := doRecursiveRename(fs, source, target, fs.renameInternal, recursion)
numFiles += files
filesSize += size
if err != nil {
return numFiles, filesSize, err
}
for _, info := range entries {
sourceEntry := fs.Join(source, info.Name())
targetEntry := fs.Join(target, info.Name())
files, size, err := fs.renameInternal(sourceEntry, targetEntry, info)
if err != nil {
if fs.IsNotExist(err) {
fsLog(fs, logger.LevelInfo, "skipping rename for %q: %v", sourceEntry, err)
continue
}
return numFiles, filesSize, err
}
numFiles += files
filesSize += size
}
}
} else {
if err := fs.copyFileInternal(source, target); err != nil {
@ -1312,3 +1254,80 @@ func (b *bufferAllocator) free() {
b.available = nil
b.finalized = true
}
type azureBlobDirLister struct {
baseDirLister
paginator *runtime.Pager[container.ListBlobsHierarchyResponse]
timeout time.Duration
prefix string
prefixes map[string]bool
metricUpdated bool
}
func (l *azureBlobDirLister) Next(limit int) ([]os.FileInfo, error) {
if limit <= 0 {
return nil, errInvalidDirListerLimit
}
if len(l.cache) >= limit {
return l.returnFromCache(limit), nil
}
if !l.paginator.More() {
if !l.metricUpdated {
l.metricUpdated = true
metric.AZListObjectsCompleted(nil)
}
return l.returnFromCache(limit), io.EOF
}
ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(l.timeout))
defer cancelFn()
page, err := l.paginator.NextPage(ctx)
if err != nil {
metric.AZListObjectsCompleted(err)
return l.cache, err
}
for _, blobPrefix := range page.ListBlobsHierarchySegmentResponse.Segment.BlobPrefixes {
name := util.GetStringFromPointer(blobPrefix.Name)
// we don't support prefixes == "/" this will be sent if a key starts with "/"
if name == "" || name == "/" {
continue
}
// sometime we have duplicate prefixes, maybe an Azurite bug
name = strings.TrimPrefix(name, l.prefix)
if _, ok := l.prefixes[strings.TrimSuffix(name, "/")]; ok {
continue
}
l.cache = append(l.cache, NewFileInfo(name, true, 0, time.Unix(0, 0), false))
l.prefixes[strings.TrimSuffix(name, "/")] = true
}
for _, blobItem := range page.ListBlobsHierarchySegmentResponse.Segment.BlobItems {
name := util.GetStringFromPointer(blobItem.Name)
name = strings.TrimPrefix(name, l.prefix)
size := int64(0)
isDir := false
modTime := time.Unix(0, 0)
if blobItem.Properties != nil {
size = util.GetIntFromPointer(blobItem.Properties.ContentLength)
modTime = util.GetTimeFromPointer(blobItem.Properties.LastModified)
contentType := util.GetStringFromPointer(blobItem.Properties.ContentType)
isDir = checkDirectoryMarkers(contentType, blobItem.Metadata)
if isDir {
// check if the dir is already included, it will be sent as blob prefix if it contains at least one item
if _, ok := l.prefixes[name]; ok {
continue
}
l.prefixes[name] = true
}
}
l.cache = append(l.cache, NewFileInfo(name, isDir, size, modTime, false))
}
return l.returnFromCache(limit), nil
}
func (l *azureBlobDirLister) Close() error {
clear(l.prefixes)
return l.baseDirLister.Close()
}

View file

@ -221,21 +221,16 @@ func (*CryptFs) Truncate(_ string, _ int64) error {
// ReadDir reads the directory named by dirname and returns
// a list of directory entries.
func (fs *CryptFs) ReadDir(dirname string) ([]os.FileInfo, error) {
func (fs *CryptFs) ReadDir(dirname string) (DirLister, error) {
f, err := os.Open(dirname)
if err != nil {
if isInvalidNameError(err) {
err = os.ErrNotExist
}
return nil, err
}
list, err := f.Readdir(-1)
f.Close()
if err != nil {
return nil, err
}
result := make([]os.FileInfo, 0, len(list))
for _, info := range list {
result = append(result, fs.ConvertFileInfo(info))
}
return result, nil
return &cryptFsDirLister{f}, nil
}
// IsUploadResumeSupported returns false sio does not support random access writes
@ -289,20 +284,7 @@ func (fs *CryptFs) getSIOConfig(key [32]byte) sio.Config {
// ConvertFileInfo returns a FileInfo with the decrypted size
func (fs *CryptFs) ConvertFileInfo(info os.FileInfo) os.FileInfo {
if !info.Mode().IsRegular() {
return info
}
size := info.Size()
if size >= headerV10Size {
size -= headerV10Size
decryptedSize, err := sio.DecryptedSize(uint64(size))
if err == nil {
size = int64(decryptedSize)
}
} else {
size = 0
}
return NewFileInfo(info.Name(), info.IsDir(), size, info.ModTime(), false)
return convertCryptFsInfo(info)
}
func (fs *CryptFs) getFileAndEncryptionKey(name string) (*os.File, [32]byte, error) {
@ -366,6 +348,23 @@ func isZeroBytesDownload(f *os.File, offset int64) (bool, error) {
return false, nil
}
func convertCryptFsInfo(info os.FileInfo) os.FileInfo {
if !info.Mode().IsRegular() {
return info
}
size := info.Size()
if size >= headerV10Size {
size -= headerV10Size
decryptedSize, err := sio.DecryptedSize(uint64(size))
if err == nil {
size = int64(decryptedSize)
}
} else {
size = 0
}
return NewFileInfo(info.Name(), info.IsDir(), size, info.ModTime(), false)
}
type encryptedFileHeader struct {
version byte
nonce []byte
@ -400,3 +399,22 @@ type cryptedFileWrapper struct {
func (w *cryptedFileWrapper) ReadAt(p []byte, offset int64) (n int, err error) {
return w.File.ReadAt(p, offset+headerV10Size)
}
type cryptFsDirLister struct {
f *os.File
}
func (l *cryptFsDirLister) Next(limit int) ([]os.FileInfo, error) {
if limit <= 0 {
return nil, errInvalidDirListerLimit
}
files, err := l.f.Readdir(limit)
for idx := range files {
files[idx] = convertCryptFsInfo(files[idx])
}
return files, err
}
func (l *cryptFsDirLister) Close() error {
return l.f.Close()
}

View file

@ -266,7 +266,7 @@ func (fs *GCSFs) Rename(source, target string) (int, int64, error) {
if err != nil {
return -1, -1, err
}
return fs.renameInternal(source, target, fi)
return fs.renameInternal(source, target, fi, 0)
}
// Remove removes the named file or (empty) directory.
@ -369,80 +369,23 @@ func (*GCSFs) Truncate(_ string, _ int64) error {
// ReadDir reads the directory named by dirname and returns
// a list of directory entries.
func (fs *GCSFs) ReadDir(dirname string) ([]os.FileInfo, error) {
var result []os.FileInfo
func (fs *GCSFs) ReadDir(dirname string) (DirLister, error) {
// dirname must be already cleaned
prefix := fs.getPrefix(dirname)
query := &storage.Query{Prefix: prefix, Delimiter: "/"}
err := query.SetAttrSelection(gcsDefaultFieldsSelection)
if err != nil {
return nil, err
}
modTimes, err := getFolderModTimes(fs.getStorageID(), dirname)
if err != nil {
return result, err
}
prefixes := make(map[string]bool)
ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxLongTimeout))
defer cancelFn()
bkt := fs.svc.Bucket(fs.config.Bucket)
it := bkt.Objects(ctx, query)
pager := iterator.NewPager(it, defaultGCSPageSize, "")
for {
var objects []*storage.ObjectAttrs
pageToken, err := pager.NextPage(&objects)
if err != nil {
metric.GCSListObjectsCompleted(err)
return result, err
}
for _, attrs := range objects {
if attrs.Prefix != "" {
name, _ := fs.resolve(attrs.Prefix, prefix, attrs.ContentType)
if name == "" {
continue
}
if _, ok := prefixes[name]; ok {
continue
}
result = append(result, NewFileInfo(name, true, 0, time.Unix(0, 0), false))
prefixes[name] = true
} else {
name, isDir := fs.resolve(attrs.Name, prefix, attrs.ContentType)
if name == "" {
continue
}
if !attrs.Deleted.IsZero() {
continue
}
if isDir {
// check if the dir is already included, it will be sent as blob prefix if it contains at least one item
if _, ok := prefixes[name]; ok {
continue
}
prefixes[name] = true
}
modTime := attrs.Updated
if t, ok := modTimes[name]; ok {
modTime = util.GetTimeFromMsecSinceEpoch(t)
}
result = append(result, NewFileInfo(name, isDir, attrs.Size, modTime, false))
}
}
objects = nil
if pageToken == "" {
break
}
}
metric.GCSListObjectsCompleted(nil)
return result, nil
return &gcsDirLister{
bucket: bkt,
query: query,
timeout: fs.ctxTimeout,
prefix: prefix,
prefixes: make(map[string]bool),
}, nil
}
// IsUploadResumeSupported returns true if resuming uploads is supported.
@ -853,7 +796,7 @@ func (fs *GCSFs) copyFileInternal(source, target string) error {
return err
}
func (fs *GCSFs) renameInternal(source, target string, fi os.FileInfo) (int, int64, error) {
func (fs *GCSFs) renameInternal(source, target string, fi os.FileInfo, recursion int) (int, int64, error) {
var numFiles int
var filesSize int64
@ -871,24 +814,12 @@ func (fs *GCSFs) renameInternal(source, target string, fi os.FileInfo) (int, int
return numFiles, filesSize, err
}
if renameMode == 1 {
entries, err := fs.ReadDir(source)
files, size, err := doRecursiveRename(fs, source, target, fs.renameInternal, recursion)
numFiles += files
filesSize += size
if err != nil {
return numFiles, filesSize, err
}
for _, info := range entries {
sourceEntry := fs.Join(source, info.Name())
targetEntry := fs.Join(target, info.Name())
files, size, err := fs.renameInternal(sourceEntry, targetEntry, info)
if err != nil {
if fs.IsNotExist(err) {
fsLog(fs, logger.LevelInfo, "skipping rename for %q: %v", sourceEntry, err)
continue
}
return numFiles, filesSize, err
}
numFiles += files
filesSize += size
}
}
} else {
if err := fs.copyFileInternal(source, target); err != nil {
@ -1010,3 +941,97 @@ func (*GCSFs) getTempObject(name string) string {
func (fs *GCSFs) getStorageID() string {
return fmt.Sprintf("gs://%v", fs.config.Bucket)
}
type gcsDirLister struct {
baseDirLister
bucket *storage.BucketHandle
query *storage.Query
timeout time.Duration
nextPageToken string
noMorePages bool
prefix string
prefixes map[string]bool
metricUpdated bool
}
func (l *gcsDirLister) resolve(name, contentType string) (string, bool) {
result := strings.TrimPrefix(name, l.prefix)
isDir := strings.HasSuffix(result, "/")
if isDir {
result = strings.TrimSuffix(result, "/")
}
if contentType == dirMimeType {
isDir = true
}
return result, isDir
}
func (l *gcsDirLister) Next(limit int) ([]os.FileInfo, error) {
if limit <= 0 {
return nil, errInvalidDirListerLimit
}
if len(l.cache) >= limit {
return l.returnFromCache(limit), nil
}
if l.noMorePages {
if !l.metricUpdated {
l.metricUpdated = true
metric.GCSListObjectsCompleted(nil)
}
return l.returnFromCache(limit), io.EOF
}
ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(l.timeout))
defer cancelFn()
it := l.bucket.Objects(ctx, l.query)
paginator := iterator.NewPager(it, defaultGCSPageSize, l.nextPageToken)
var objects []*storage.ObjectAttrs
pageToken, err := paginator.NextPage(&objects)
if err != nil {
metric.GCSListObjectsCompleted(err)
return l.cache, err
}
for _, attrs := range objects {
if attrs.Prefix != "" {
name, _ := l.resolve(attrs.Prefix, attrs.ContentType)
if name == "" {
continue
}
if _, ok := l.prefixes[name]; ok {
continue
}
l.cache = append(l.cache, NewFileInfo(name, true, 0, time.Unix(0, 0), false))
l.prefixes[name] = true
} else {
name, isDir := l.resolve(attrs.Name, attrs.ContentType)
if name == "" {
continue
}
if !attrs.Deleted.IsZero() {
continue
}
if isDir {
// check if the dir is already included, it will be sent as blob prefix if it contains at least one item
if _, ok := l.prefixes[name]; ok {
continue
}
l.prefixes[name] = true
}
l.cache = append(l.cache, NewFileInfo(name, isDir, attrs.Size, attrs.Updated, false))
}
}
l.nextPageToken = pageToken
l.noMorePages = (l.nextPageToken == "")
return l.returnFromCache(limit), nil
}
func (l *gcsDirLister) Close() error {
clear(l.prefixes)
return l.baseDirLister.Close()
}

View file

@ -488,7 +488,7 @@ func (fs *HTTPFs) Truncate(name string, size int64) error {
// ReadDir reads the directory named by dirname and returns
// a list of directory entries.
func (fs *HTTPFs) ReadDir(dirname string) ([]os.FileInfo, error) {
func (fs *HTTPFs) ReadDir(dirname string) (DirLister, error) {
ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout))
defer cancelFn()
@ -511,7 +511,7 @@ func (fs *HTTPFs) ReadDir(dirname string) ([]os.FileInfo, error) {
for _, stat := range response {
result = append(result, stat.getFileInfo())
}
return result, nil
return &baseDirLister{result}, nil
}
// IsUploadResumeSupported returns true if resuming uploads is supported.
@ -731,19 +731,33 @@ func (fs *HTTPFs) walk(filePath string, info fs.FileInfo, walkFn filepath.WalkFu
if !info.IsDir() {
return walkFn(filePath, info, nil)
}
files, err := fs.ReadDir(filePath)
lister, err := fs.ReadDir(filePath)
err1 := walkFn(filePath, info, err)
if err != nil || err1 != nil {
if err == nil {
lister.Close()
}
return err1
}
for _, fi := range files {
objName := path.Join(filePath, fi.Name())
err = fs.walk(objName, fi, walkFn)
if err != nil {
defer lister.Close()
for {
files, err := lister.Next(ListerBatchSize)
finished := errors.Is(err, io.EOF)
if err != nil && !finished {
return err
}
for _, fi := range files {
objName := path.Join(filePath, fi.Name())
err = fs.walk(objName, fi, walkFn)
if err != nil {
return err
}
}
if finished {
return nil
}
}
return nil
}
func getErrorFromResponseCode(code int) error {

View file

@ -190,7 +190,7 @@ func (fs *OsFs) Rename(source, target string) (int, int64, error) {
}
err = fscopy.Copy(source, target, fscopy.Options{
OnSymlink: func(src string) fscopy.SymlinkAction {
OnSymlink: func(_ string) fscopy.SymlinkAction {
return fscopy.Skip
},
CopyBufferSize: readBufferSize,
@ -258,7 +258,7 @@ func (*OsFs) Truncate(name string, size int64) error {
// ReadDir reads the directory named by dirname and returns
// a list of directory entries.
func (*OsFs) ReadDir(dirname string) ([]os.FileInfo, error) {
func (*OsFs) ReadDir(dirname string) (DirLister, error) {
f, err := os.Open(dirname)
if err != nil {
if isInvalidNameError(err) {
@ -266,12 +266,7 @@ func (*OsFs) ReadDir(dirname string) ([]os.FileInfo, error) {
}
return nil, err
}
list, err := f.Readdir(-1)
f.Close()
if err != nil {
return nil, err
}
return list, nil
return &osFsDirLister{f}, nil
}
// IsUploadResumeSupported returns true if resuming uploads is supported
@ -599,3 +594,18 @@ func (fs *OsFs) useWriteBuffering(flag int) bool {
}
return true
}
type osFsDirLister struct {
f *os.File
}
func (l *osFsDirLister) Next(limit int) ([]os.FileInfo, error) {
if limit <= 0 {
return nil, errInvalidDirListerLimit
}
return l.f.Readdir(limit)
}
func (l *osFsDirLister) Close() error {
return l.f.Close()
}

View file

@ -22,6 +22,7 @@ import (
"crypto/tls"
"errors"
"fmt"
"io"
"mime"
"net"
"net/http"
@ -61,7 +62,8 @@ const (
)
var (
s3DirMimeTypes = []string{s3DirMimeType, "httpd/unix-directory"}
s3DirMimeTypes = []string{s3DirMimeType, "httpd/unix-directory"}
s3DefaultPageSize = int32(5000)
)
// S3Fs is a Fs implementation for AWS S3 compatible object storages
@ -337,7 +339,7 @@ func (fs *S3Fs) Rename(source, target string) (int, int64, error) {
if err != nil {
return -1, -1, err
}
return fs.renameInternal(source, target, fi)
return fs.renameInternal(source, target, fi, 0)
}
// Remove removes the named file or (empty) directory.
@ -426,68 +428,22 @@ func (*S3Fs) Truncate(_ string, _ int64) error {
// ReadDir reads the directory named by dirname and returns
// a list of directory entries.
func (fs *S3Fs) ReadDir(dirname string) ([]os.FileInfo, error) {
var result []os.FileInfo
func (fs *S3Fs) ReadDir(dirname string) (DirLister, error) {
// dirname must be already cleaned
prefix := fs.getPrefix(dirname)
modTimes, err := getFolderModTimes(fs.getStorageID(), dirname)
if err != nil {
return result, err
}
prefixes := make(map[string]bool)
paginator := s3.NewListObjectsV2Paginator(fs.svc, &s3.ListObjectsV2Input{
Bucket: aws.String(fs.config.Bucket),
Prefix: aws.String(prefix),
Delimiter: aws.String("/"),
MaxKeys: &s3DefaultPageSize,
})
for paginator.HasMorePages() {
ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout))
defer cancelFn()
page, err := paginator.NextPage(ctx)
if err != nil {
metric.S3ListObjectsCompleted(err)
return result, err
}
for _, p := range page.CommonPrefixes {
// prefixes have a trailing slash
name, _ := fs.resolve(p.Prefix, prefix)
if name == "" {
continue
}
if _, ok := prefixes[name]; ok {
continue
}
result = append(result, NewFileInfo(name, true, 0, time.Unix(0, 0), false))
prefixes[name] = true
}
for _, fileObject := range page.Contents {
objectModTime := util.GetTimeFromPointer(fileObject.LastModified)
objectSize := util.GetIntFromPointer(fileObject.Size)
name, isDir := fs.resolve(fileObject.Key, prefix)
if name == "" || name == "/" {
continue
}
if isDir {
if _, ok := prefixes[name]; ok {
continue
}
prefixes[name] = true
}
if t, ok := modTimes[name]; ok {
objectModTime = util.GetTimeFromMsecSinceEpoch(t)
}
result = append(result, NewFileInfo(name, (isDir && objectSize == 0), objectSize,
objectModTime, false))
}
}
metric.S3ListObjectsCompleted(nil)
return result, nil
return &s3DirLister{
paginator: paginator,
timeout: fs.ctxTimeout,
prefix: prefix,
prefixes: make(map[string]bool),
}, nil
}
// IsUploadResumeSupported returns true if resuming uploads is supported.
@ -574,6 +530,7 @@ func (fs *S3Fs) getFileNamesInPrefix(fsPrefix string) (map[string]bool, error) {
Bucket: aws.String(fs.config.Bucket),
Prefix: aws.String(prefix),
Delimiter: aws.String("/"),
MaxKeys: &s3DefaultPageSize,
})
for paginator.HasMorePages() {
@ -614,8 +571,9 @@ func (fs *S3Fs) GetDirSize(dirname string) (int, int64, error) {
size := int64(0)
paginator := s3.NewListObjectsV2Paginator(fs.svc, &s3.ListObjectsV2Input{
Bucket: aws.String(fs.config.Bucket),
Prefix: aws.String(prefix),
Bucket: aws.String(fs.config.Bucket),
Prefix: aws.String(prefix),
MaxKeys: &s3DefaultPageSize,
})
for paginator.HasMorePages() {
@ -679,8 +637,9 @@ func (fs *S3Fs) Walk(root string, walkFn filepath.WalkFunc) error {
prefix := fs.getPrefix(root)
paginator := s3.NewListObjectsV2Paginator(fs.svc, &s3.ListObjectsV2Input{
Bucket: aws.String(fs.config.Bucket),
Prefix: aws.String(prefix),
Bucket: aws.String(fs.config.Bucket),
Prefix: aws.String(prefix),
MaxKeys: &s3DefaultPageSize,
})
for paginator.HasMorePages() {
@ -797,7 +756,7 @@ func (fs *S3Fs) copyFileInternal(source, target string, fileSize int64) error {
return err
}
func (fs *S3Fs) renameInternal(source, target string, fi os.FileInfo) (int, int64, error) {
func (fs *S3Fs) renameInternal(source, target string, fi os.FileInfo, recursion int) (int, int64, error) {
var numFiles int
var filesSize int64
@ -815,24 +774,12 @@ func (fs *S3Fs) renameInternal(source, target string, fi os.FileInfo) (int, int6
return numFiles, filesSize, err
}
if renameMode == 1 {
entries, err := fs.ReadDir(source)
files, size, err := doRecursiveRename(fs, source, target, fs.renameInternal, recursion)
numFiles += files
filesSize += size
if err != nil {
return numFiles, filesSize, err
}
for _, info := range entries {
sourceEntry := fs.Join(source, info.Name())
targetEntry := fs.Join(target, info.Name())
files, size, err := fs.renameInternal(sourceEntry, targetEntry, info)
if err != nil {
if fs.IsNotExist(err) {
fsLog(fs, logger.LevelInfo, "skipping rename for %q: %v", sourceEntry, err)
continue
}
return numFiles, filesSize, err
}
numFiles += files
filesSize += size
}
}
} else {
if err := fs.copyFileInternal(source, target, fi.Size()); err != nil {
@ -1114,6 +1061,81 @@ func (fs *S3Fs) getStorageID() string {
return fmt.Sprintf("s3://%v", fs.config.Bucket)
}
type s3DirLister struct {
baseDirLister
paginator *s3.ListObjectsV2Paginator
timeout time.Duration
prefix string
prefixes map[string]bool
metricUpdated bool
}
func (l *s3DirLister) resolve(name *string) (string, bool) {
result := strings.TrimPrefix(util.GetStringFromPointer(name), l.prefix)
isDir := strings.HasSuffix(result, "/")
if isDir {
result = strings.TrimSuffix(result, "/")
}
return result, isDir
}
func (l *s3DirLister) Next(limit int) ([]os.FileInfo, error) {
if limit <= 0 {
return nil, errInvalidDirListerLimit
}
if len(l.cache) >= limit {
return l.returnFromCache(limit), nil
}
if !l.paginator.HasMorePages() {
if !l.metricUpdated {
l.metricUpdated = true
metric.S3ListObjectsCompleted(nil)
}
return l.returnFromCache(limit), io.EOF
}
ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(l.timeout))
defer cancelFn()
page, err := l.paginator.NextPage(ctx)
if err != nil {
metric.S3ListObjectsCompleted(err)
return l.cache, err
}
for _, p := range page.CommonPrefixes {
// prefixes have a trailing slash
name, _ := l.resolve(p.Prefix)
if name == "" {
continue
}
if _, ok := l.prefixes[name]; ok {
continue
}
l.cache = append(l.cache, NewFileInfo(name, true, 0, time.Unix(0, 0), false))
l.prefixes[name] = true
}
for _, fileObject := range page.Contents {
objectModTime := util.GetTimeFromPointer(fileObject.LastModified)
objectSize := util.GetIntFromPointer(fileObject.Size)
name, isDir := l.resolve(fileObject.Key)
if name == "" || name == "/" {
continue
}
if isDir {
if _, ok := l.prefixes[name]; ok {
continue
}
l.prefixes[name] = true
}
l.cache = append(l.cache, NewFileInfo(name, (isDir && objectSize == 0), objectSize, objectModTime, false))
}
return l.returnFromCache(limit), nil
}
func (l *s3DirLister) Close() error {
return l.baseDirLister.Close()
}
func getAWSHTTPClient(timeout int, idleConnectionTimeout time.Duration, skipTLSVerify bool) *awshttp.BuildableClient {
c := awshttp.NewBuildableClient().
WithDialerOptions(func(d *net.Dialer) {

View file

@ -541,12 +541,16 @@ func (fs *SFTPFs) Truncate(name string, size int64) error {
// ReadDir reads the directory named by dirname and returns
// a list of directory entries.
func (fs *SFTPFs) ReadDir(dirname string) ([]os.FileInfo, error) {
func (fs *SFTPFs) ReadDir(dirname string) (DirLister, error) {
client, err := fs.conn.getClient()
if err != nil {
return nil, err
}
return client.ReadDir(dirname)
files, err := client.ReadDir(dirname)
if err != nil {
return nil, err
}
return &baseDirLister{files}, nil
}
// IsUploadResumeSupported returns true if resuming uploads is supported.

View file

@ -45,6 +45,8 @@ const (
gcsfsName = "GCSFs"
azBlobFsName = "AzureBlobFs"
preResumeTimeout = 90 * time.Second
// ListerBatchSize defines the default limit for DirLister implementations
ListerBatchSize = 1000
)
// Additional checks for files
@ -58,14 +60,15 @@ var (
// ErrStorageSizeUnavailable is returned if the storage backend does not support getting the size
ErrStorageSizeUnavailable = errors.New("unable to get available size for this storage backend")
// ErrVfsUnsupported defines the error for an unsupported VFS operation
ErrVfsUnsupported = errors.New("not supported")
tempPath string
sftpFingerprints []string
allowSelfConnections int
renameMode int
readMetadata int
resumeMaxSize int64
uploadMode int
ErrVfsUnsupported = errors.New("not supported")
errInvalidDirListerLimit = errors.New("dir lister: invalid limit, must be > 0")
tempPath string
sftpFingerprints []string
allowSelfConnections int
renameMode int
readMetadata int
resumeMaxSize int64
uploadMode int
)
// SetAllowSelfConnections sets the desired behaviour for self connections
@ -125,7 +128,7 @@ type Fs interface {
Chmod(name string, mode os.FileMode) error
Chtimes(name string, atime, mtime time.Time, isUploading bool) error
Truncate(name string, size int64) error
ReadDir(dirname string) ([]os.FileInfo, error)
ReadDir(dirname string) (DirLister, error)
Readlink(name string) (string, error)
IsUploadResumeSupported() bool
IsConditionalUploadResumeSupported(size int64) bool
@ -199,11 +202,47 @@ type PipeReader interface {
Metadata() map[string]string
}
// DirLister defines an interface for a directory lister
type DirLister interface {
Next(limit int) ([]os.FileInfo, error)
Close() error
}
// Metadater defines an interface to implement to return metadata for a file
type Metadater interface {
Metadata() map[string]string
}
type baseDirLister struct {
cache []os.FileInfo
}
func (l *baseDirLister) Next(limit int) ([]os.FileInfo, error) {
if limit <= 0 {
return nil, errInvalidDirListerLimit
}
if len(l.cache) >= limit {
return l.returnFromCache(limit), nil
}
return l.returnFromCache(limit), io.EOF
}
func (l *baseDirLister) returnFromCache(limit int) []os.FileInfo {
if len(l.cache) >= limit {
result := l.cache[:limit]
l.cache = l.cache[limit:]
return result
}
result := l.cache
l.cache = nil
return result
}
func (l *baseDirLister) Close() error {
l.cache = nil
return nil
}
// QuotaCheckResult defines the result for a quota check
type QuotaCheckResult struct {
HasSpace bool
@ -1069,18 +1108,6 @@ func updateFileInfoModTime(storageID, objectPath string, info *FileInfo) (*FileI
return info, nil
}
func getFolderModTimes(storageID, dirName string) (map[string]int64, error) {
var err error
modTimes := make(map[string]int64)
if plugin.Handler.HasMetadater() {
modTimes, err = plugin.Handler.GetModificationTimes(storageID, ensureAbsPath(dirName))
if err != nil && !errors.Is(err, metadata.ErrNoSuchObject) {
return modTimes, err
}
}
return modTimes, nil
}
func ensureAbsPath(name string) string {
if path.IsAbs(name) {
return name
@ -1205,6 +1232,50 @@ func getLocalTempDir() string {
return filepath.Clean(os.TempDir())
}
func doRecursiveRename(fs Fs, source, target string,
renameFn func(string, string, os.FileInfo, int) (int, int64, error),
recursion int,
) (int, int64, error) {
var numFiles int
var filesSize int64
if recursion > util.MaxRecursion {
return numFiles, filesSize, util.ErrRecursionTooDeep
}
recursion++
lister, err := fs.ReadDir(source)
if err != nil {
return numFiles, filesSize, err
}
defer lister.Close()
for {
entries, err := lister.Next(ListerBatchSize)
finished := errors.Is(err, io.EOF)
if err != nil && !finished {
return numFiles, filesSize, err
}
for _, info := range entries {
sourceEntry := fs.Join(source, info.Name())
targetEntry := fs.Join(target, info.Name())
files, size, err := renameFn(sourceEntry, targetEntry, info, recursion)
if err != nil {
if fs.IsNotExist(err) {
fsLog(fs, logger.LevelInfo, "skipping rename for %q: %v", sourceEntry, err)
continue
}
return numFiles, filesSize, err
}
numFiles += files
filesSize += size
}
if finished {
return numFiles, filesSize, nil
}
}
}
func fsLog(fs Fs, level logger.LogLevel, format string, v ...any) {
logger.Log(level, fs.Name(), fs.ConnectionID(), format, v...)
}

View file

@ -108,22 +108,24 @@ func (fi *webDavFileInfo) ContentType(_ context.Context) (string, error) {
// Readdir reads directory entries from the handle
func (f *webDavFile) Readdir(_ int) ([]os.FileInfo, error) {
return nil, webdav.ErrNotImplemented
}
// ReadDir implements the FileDirLister interface
func (f *webDavFile) ReadDir() (webdav.DirLister, error) {
if !f.Connection.User.HasPerm(dataprovider.PermListItems, f.GetVirtualPath()) {
return nil, f.Connection.GetPermissionDeniedError()
}
entries, err := f.Connection.ListDir(f.GetVirtualPath())
lister, err := f.Connection.ListDir(f.GetVirtualPath())
if err != nil {
return nil, err
}
for idx, info := range entries {
entries[idx] = &webDavFileInfo{
FileInfo: info,
Fs: f.Fs,
virtualPath: path.Join(f.GetVirtualPath(), info.Name()),
fsPath: f.Fs.Join(f.GetFsPath(), info.Name()),
}
}
return entries, nil
return &webDavDirLister{
DirLister: lister,
fs: f.Fs,
virtualDirPath: f.GetVirtualPath(),
fsDirPath: f.GetFsPath(),
}, nil
}
// Stat the handle
@ -474,3 +476,24 @@ func (f *webDavFile) Patch(patches []webdav.Proppatch) ([]webdav.Propstat, error
}
return resp, nil
}
type webDavDirLister struct {
vfs.DirLister
fs vfs.Fs
virtualDirPath string
fsDirPath string
}
func (l *webDavDirLister) Next(limit int) ([]os.FileInfo, error) {
files, err := l.DirLister.Next(limit)
for idx := range files {
info := files[idx]
files[idx] = &webDavFileInfo{
FileInfo: info,
Fs: l.fs,
virtualPath: path.Join(l.virtualDirPath, info.Name()),
fsPath: l.fs.Join(l.fsDirPath, info.Name()),
}
}
return files, err
}

View file

@ -692,6 +692,8 @@ func TestContentType(t *testing.T) {
assert.Equal(t, "application/custom-mime", ctype)
}
_, err = davFile.Readdir(-1)
assert.ErrorIs(t, err, webdav.ErrNotImplemented)
_, err = davFile.ReadDir()
assert.Error(t, err)
err = davFile.Close()
assert.NoError(t, err)