diff --git a/go.mod b/go.mod index ba74673b..94aa471e 100644 --- a/go.mod +++ b/go.mod @@ -21,7 +21,7 @@ require ( github.com/bmatcuk/doublestar/v4 v4.6.1 github.com/cockroachdb/cockroach-go/v2 v2.3.6 github.com/coreos/go-oidc/v3 v3.9.0 - github.com/drakkan/webdav v0.0.0-20230227175313-32996838bcd8 + github.com/drakkan/webdav v0.0.0-20240212101318-94e905cb9adb github.com/eikenb/pipeat v0.0.0-20210730190139-06b3e6902001 github.com/fclairamb/ftpserverlib v0.22.0 github.com/fclairamb/go-log v0.4.1 @@ -71,7 +71,7 @@ require ( golang.org/x/crypto v0.18.0 golang.org/x/net v0.20.0 golang.org/x/oauth2 v0.16.0 - golang.org/x/sys v0.16.0 + golang.org/x/sys v0.17.0 golang.org/x/term v0.16.0 golang.org/x/time v0.5.0 google.golang.org/api v0.161.0 @@ -118,7 +118,9 @@ require ( github.com/google/s2a-go v0.1.7 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect github.com/googleapis/gax-go/v2 v2.12.0 // indirect + github.com/hashicorp/errwrap v1.0.0 // indirect github.com/hashicorp/go-cleanhttp v0.5.2 // indirect + github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/hashicorp/yamux v0.1.1 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect @@ -181,8 +183,9 @@ require ( ) replace ( - github.com/fclairamb/ftpserverlib => github.com/drakkan/ftpserverlib v0.0.0-20230820193955-e7243edeb89b - github.com/jlaffaye/ftp => github.com/drakkan/ftp v0.0.0-20201114075148-9b9adce499a9 + github.com/fclairamb/ftpserverlib => github.com/drakkan/ftpserverlib v0.0.0-20240212100826-a241365cb085 + github.com/jlaffaye/ftp => github.com/drakkan/ftp v0.0.0-20240210102745-f1ffc43f78d2 + github.com/pkg/sftp => github.com/drakkan/sftp v0.0.0-20240214104840-fbb0b8bdb30c github.com/robfig/cron/v3 => github.com/drakkan/cron/v3 v3.0.0-20230222140221-217a1e4d96c0 golang.org/x/crypto => github.com/drakkan/crypto v0.0.0-20231218163632-74b52eafd2c0 ) diff --git a/go.sum b/go.sum index 8494987d..350be01b 100644 --- a/go.sum +++ b/go.sum @@ -113,12 +113,14 @@ github.com/drakkan/cron/v3 v3.0.0-20230222140221-217a1e4d96c0 h1:EW9gIJRmt9lzk66 github.com/drakkan/cron/v3 v3.0.0-20230222140221-217a1e4d96c0/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= github.com/drakkan/crypto v0.0.0-20231218163632-74b52eafd2c0 h1:Yel8NcrK4jg+biIcTxnszKh0eIpF2Vj25XEygQcTweI= github.com/drakkan/crypto v0.0.0-20231218163632-74b52eafd2c0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= -github.com/drakkan/ftp v0.0.0-20201114075148-9b9adce499a9 h1:LPH1dEblAOO/LoG7yHPMtBLXhQmjaga91/DDjWk9jWA= -github.com/drakkan/ftp v0.0.0-20201114075148-9b9adce499a9/go.mod h1:2lmrmq866uF2tnje75wQHzmPXhmSWUt7Gyx2vgK1RCU= -github.com/drakkan/ftpserverlib v0.0.0-20230820193955-e7243edeb89b h1:sCtiYerLxfOQrSludkwGwwXLlSVHxpvfmyOxjCOf0ec= -github.com/drakkan/ftpserverlib v0.0.0-20230820193955-e7243edeb89b/go.mod h1:dI9/yw/KfJ0g4wmRK8ZukUfqakLr6ZTf9VDydKoLy90= -github.com/drakkan/webdav v0.0.0-20230227175313-32996838bcd8 h1:tdkLkSKtYd3WSDsZXGJDKsakiNstLQJPN5HjnqCkf2c= -github.com/drakkan/webdav v0.0.0-20230227175313-32996838bcd8/go.mod h1:zOVb1QDhwwqWn2L2qZ0U3swMSO4GTSNyIwXCGO/UGWE= +github.com/drakkan/ftp v0.0.0-20240210102745-f1ffc43f78d2 h1:ufiGMPFBjndWSQOst9FNP11IuMqPblI2NXbpRMUWNhk= +github.com/drakkan/ftp v0.0.0-20240210102745-f1ffc43f78d2/go.mod h1:4p8lUl4vQ80L598CygL+3IFtm+3nggvvW/palOlViwE= +github.com/drakkan/ftpserverlib v0.0.0-20240212100826-a241365cb085 h1:LAKYR9z9USKeyEQK91sRWldmMOjEHLOt2NuLDx+x1UQ= +github.com/drakkan/ftpserverlib v0.0.0-20240212100826-a241365cb085/go.mod h1:9rZ27KBV3xlXmjIfd6HynND28tse8ShZJ/NQkprCKno= +github.com/drakkan/sftp v0.0.0-20240214104840-fbb0b8bdb30c h1:usPo/2W6Dj2rugQiEml0pwmUfY/wUgW6nLGl+q98c5k= +github.com/drakkan/sftp v0.0.0-20240214104840-fbb0b8bdb30c/go.mod h1:KMKI0t3T6hfA+lTR/ssZdunHo+uwq7ghoN09/FSu3DY= +github.com/drakkan/webdav v0.0.0-20240212101318-94e905cb9adb h1:BLT+1m0U57PevUtACnTUoMErsLJyK2ydeNiVuX8R3Lk= +github.com/drakkan/webdav v0.0.0-20240212101318-94e905cb9adb/go.mod h1:zOVb1QDhwwqWn2L2qZ0U3swMSO4GTSNyIwXCGO/UGWE= github.com/eikenb/pipeat v0.0.0-20210730190139-06b3e6902001 h1:/ZshrfQzayqRSBDodmp3rhNCHJCff+utvgBuWRbiqu4= github.com/eikenb/pipeat v0.0.0-20210730190139-06b3e6902001/go.mod h1:kltMsfRMTHSFdMbK66XdS8mfMW77+FZA1fGY1xYMF84= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= @@ -216,11 +218,15 @@ github.com/googleapis/enterprise-certificate-proxy v0.3.2 h1:Vie5ybvEvT75RniqhfF github.com/googleapis/enterprise-certificate-proxy v0.3.2/go.mod h1:VLSiSSBs/ksPL8kq3OBOQ6WRI2QnaFynd1DCjZ62+V0= github.com/googleapis/gax-go/v2 v2.12.0 h1:A+gCJKdRfqXkr+BIRGtZLibNXf0m1f9E4HG56etFpas= github.com/googleapis/gax-go/v2 v2.12.0/go.mod h1:y+aIqrI5eb1YGMVJfuV3185Ts/D7qKpsEkdD5+I6QGU= +github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= +github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= github.com/hashicorp/go-hclog v0.9.2/go.mod h1:5CU+agLiy3J7N7QjHK5d05KxGsuXiQLrjA0H7acj2lQ= github.com/hashicorp/go-hclog v1.6.2 h1:NOtoftovWkDheyUM/8JW3QMiXyxJK3uHRK7wV04nD2I= github.com/hashicorp/go-hclog v1.6.2/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M= +github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= +github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= github.com/hashicorp/go-plugin v1.6.0 h1:wgd4KxHJTVGGqWBq4QPB1i5BZNEx9BR8+OFmHDmTk8A= github.com/hashicorp/go-plugin v1.6.0/go.mod h1:lBS5MtSSBZk0SHc66KACcjjlU6WzEVP/8pwz68aMkCI= github.com/hashicorp/go-retryablehttp v0.7.5 h1:bJj+Pj19UZMIweq/iie+1u5YCdGrnxCT9yvm0e+Nd5M= @@ -311,8 +317,6 @@ github.com/pires/go-proxyproto v0.7.0/go.mod h1:Vz/1JPY/OACxWGQNIRY2BeyDmpoaWmEP github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 h1:KoWmjvw+nsYOo29YJK9vDA65RGE3NrOnUtO7a+RF9HU= github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8/go.mod h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/pkg/sftp v1.13.6 h1:JFZT4XbOU7l77xGSpOdW+pwIMqP044IyjXX6FGyEKFo= -github.com/pkg/sftp v1.13.6/go.mod h1:tz1ryNURKu77RL+GuCzmoJYxQczL3wLNNpPWagdg4Qk= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -487,8 +491,9 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU= golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= +golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= diff --git a/internal/common/connection.go b/internal/common/connection.go index de3dda92..a8374f64 100644 --- a/internal/common/connection.go +++ b/internal/common/connection.go @@ -297,7 +297,7 @@ func (c *BaseConnection) truncateOpenHandle(fsPath string, size int64) (int64, e } // ListDir reads the directory matching virtualPath and returns a list of directory entries -func (c *BaseConnection) ListDir(virtualPath string) ([]os.FileInfo, error) { +func (c *BaseConnection) ListDir(virtualPath string) (*DirListerAt, error) { if !c.User.HasPerm(dataprovider.PermListItems, virtualPath) { return nil, c.GetPermissionDeniedError() } @@ -305,12 +305,17 @@ func (c *BaseConnection) ListDir(virtualPath string) ([]os.FileInfo, error) { if err != nil { return nil, err } - files, err := fs.ReadDir(fsPath) + lister, err := fs.ReadDir(fsPath) if err != nil { c.Log(logger.LevelDebug, "error listing directory: %+v", err) return nil, c.GetFsError(fs, err) } - return c.User.FilterListDir(files, virtualPath), nil + return &DirListerAt{ + virtualPath: virtualPath, + user: &c.User, + info: c.User.GetVirtualFoldersInfo(virtualPath), + lister: lister, + }, nil } // CheckParentDirs tries to create the specified directory and any missing parent dirs @@ -511,24 +516,42 @@ func (c *BaseConnection) RemoveDir(virtualPath string) error { return nil } -func (c *BaseConnection) doRecursiveRemoveDirEntry(virtualPath string, info os.FileInfo) error { +func (c *BaseConnection) doRecursiveRemoveDirEntry(virtualPath string, info os.FileInfo, recursion int) error { fs, fsPath, err := c.GetFsAndResolvedPath(virtualPath) if err != nil { return err } - return c.doRecursiveRemove(fs, fsPath, virtualPath, info) + return c.doRecursiveRemove(fs, fsPath, virtualPath, info, recursion) } -func (c *BaseConnection) doRecursiveRemove(fs vfs.Fs, fsPath, virtualPath string, info os.FileInfo) error { +func (c *BaseConnection) doRecursiveRemove(fs vfs.Fs, fsPath, virtualPath string, info os.FileInfo, recursion int) error { if info.IsDir() { - entries, err := c.ListDir(virtualPath) - if err != nil { - return fmt.Errorf("unable to get contents for dir %q: %w", virtualPath, err) + if recursion >= util.MaxRecursion { + c.Log(logger.LevelError, "recursive rename failed, recursion too depth: %d", recursion) + return util.ErrRecursionTooDeep } - for _, fi := range entries { - targetPath := path.Join(virtualPath, fi.Name()) - if err := c.doRecursiveRemoveDirEntry(targetPath, fi); err != nil { - return err + recursion++ + lister, err := c.ListDir(virtualPath) + if err != nil { + return fmt.Errorf("unable to get lister for dir %q: %w", virtualPath, err) + } + defer lister.Close() + + for { + entries, err := lister.Next(vfs.ListerBatchSize) + finished := errors.Is(err, io.EOF) + if err != nil && !finished { + return fmt.Errorf("unable to get content for dir %q: %w", virtualPath, err) + } + for _, fi := range entries { + targetPath := path.Join(virtualPath, fi.Name()) + if err := c.doRecursiveRemoveDirEntry(targetPath, fi, recursion); err != nil { + return err + } + } + if finished { + lister.Close() + break } } return c.RemoveDir(virtualPath) @@ -552,7 +575,7 @@ func (c *BaseConnection) RemoveAll(virtualPath string) error { if err := c.IsRemoveDirAllowed(fs, fsPath, virtualPath); err != nil { return err } - return c.doRecursiveRemove(fs, fsPath, virtualPath, fi) + return c.doRecursiveRemove(fs, fsPath, virtualPath, fi, 0) } return c.RemoveFile(fs, fsPath, virtualPath, fi) } @@ -626,43 +649,38 @@ func (c *BaseConnection) copyFile(virtualSourcePath, virtualTargetPath string, s } func (c *BaseConnection) doRecursiveCopy(virtualSourcePath, virtualTargetPath string, srcInfo os.FileInfo, - createTargetDir bool, + createTargetDir bool, recursion int, ) error { if srcInfo.IsDir() { + if recursion >= util.MaxRecursion { + c.Log(logger.LevelError, "recursive copy failed, recursion too depth: %d", recursion) + return util.ErrRecursionTooDeep + } + recursion++ if createTargetDir { if err := c.CreateDir(virtualTargetPath, false); err != nil { return fmt.Errorf("unable to create directory %q: %w", virtualTargetPath, err) } } - entries, err := c.ListDir(virtualSourcePath) + lister, err := c.ListDir(virtualSourcePath) if err != nil { - return fmt.Errorf("unable to get contents for dir %q: %w", virtualSourcePath, err) + return fmt.Errorf("unable to get lister for dir %q: %w", virtualSourcePath, err) } - for _, info := range entries { - sourcePath := path.Join(virtualSourcePath, info.Name()) - targetPath := path.Join(virtualTargetPath, info.Name()) - targetInfo, err := c.DoStat(targetPath, 1, false) - if err == nil { - if info.IsDir() && targetInfo.IsDir() { - c.Log(logger.LevelDebug, "target copy dir %q already exists", targetPath) - continue - } + defer lister.Close() + + for { + entries, err := lister.Next(vfs.ListerBatchSize) + finished := errors.Is(err, io.EOF) + if err != nil && !finished { + return fmt.Errorf("unable to get contents for dir %q: %w", virtualSourcePath, err) } - if err != nil && !c.IsNotExistError(err) { + if err := c.recursiveCopyEntries(virtualSourcePath, virtualTargetPath, entries, recursion); err != nil { return err } - if err := c.checkCopy(info, targetInfo, sourcePath, targetPath); err != nil { - return err - } - if err := c.doRecursiveCopy(sourcePath, targetPath, info, true); err != nil { - if c.IsNotExistError(err) { - c.Log(logger.LevelInfo, "skipping copy for source path %q: %v", sourcePath, err) - continue - } - return err + if finished { + return nil } } - return nil } if !srcInfo.Mode().IsRegular() { c.Log(logger.LevelInfo, "skipping copy for non regular file %q", virtualSourcePath) @@ -672,6 +690,34 @@ func (c *BaseConnection) doRecursiveCopy(virtualSourcePath, virtualTargetPath st return c.copyFile(virtualSourcePath, virtualTargetPath, srcInfo.Size()) } +func (c *BaseConnection) recursiveCopyEntries(virtualSourcePath, virtualTargetPath string, entries []os.FileInfo, recursion int) error { + for _, info := range entries { + sourcePath := path.Join(virtualSourcePath, info.Name()) + targetPath := path.Join(virtualTargetPath, info.Name()) + targetInfo, err := c.DoStat(targetPath, 1, false) + if err == nil { + if info.IsDir() && targetInfo.IsDir() { + c.Log(logger.LevelDebug, "target copy dir %q already exists", targetPath) + continue + } + } + if err != nil && !c.IsNotExistError(err) { + return err + } + if err := c.checkCopy(info, targetInfo, sourcePath, targetPath); err != nil { + return err + } + if err := c.doRecursiveCopy(sourcePath, targetPath, info, true, recursion); err != nil { + if c.IsNotExistError(err) { + c.Log(logger.LevelInfo, "skipping copy for source path %q: %v", sourcePath, err) + continue + } + return err + } + } + return nil +} + // Copy virtualSourcePath to virtualTargetPath func (c *BaseConnection) Copy(virtualSourcePath, virtualTargetPath string) error { copyFromSource := strings.HasSuffix(virtualSourcePath, "/") @@ -717,7 +763,7 @@ func (c *BaseConnection) Copy(virtualSourcePath, virtualTargetPath string) error defer close(done) go keepConnectionAlive(c, done, 2*time.Minute) - return c.doRecursiveCopy(virtualSourcePath, destPath, srcInfo, createTargetDir) + return c.doRecursiveCopy(virtualSourcePath, destPath, srcInfo, createTargetDir, 0) } // Rename renames (moves) virtualSourcePath to virtualTargetPath @@ -865,7 +911,8 @@ func (c *BaseConnection) doStatInternal(virtualPath string, mode int, checkFileP convertResult bool, ) (os.FileInfo, error) { // for some vfs we don't create intermediary folders so we cannot simply check - // if virtualPath is a virtual folder + // if virtualPath is a virtual folder. Allowing stat for hidden virtual folders + // is by purpose. vfolders := c.User.GetVirtualFoldersInPath(path.Dir(virtualPath)) if _, ok := vfolders[virtualPath]; ok { return vfs.NewFileInfo(virtualPath, true, 0, time.Unix(0, 0), false), nil @@ -1739,6 +1786,83 @@ func (c *BaseConnection) GetFsAndResolvedPath(virtualPath string) (vfs.Fs, strin return fs, fsPath, nil } +// DirListerAt defines a directory lister implementing the ListAt method. +type DirListerAt struct { + virtualPath string + user *dataprovider.User + info []os.FileInfo + mu sync.Mutex + lister vfs.DirLister +} + +// Add adds the given os.FileInfo to the internal cache +func (l *DirListerAt) Add(fi os.FileInfo) { + l.mu.Lock() + defer l.mu.Unlock() + + l.info = append(l.info, fi) +} + +// ListAt implements sftp.ListerAt +func (l *DirListerAt) ListAt(f []os.FileInfo, _ int64) (int, error) { + l.mu.Lock() + defer l.mu.Unlock() + + if len(f) == 0 { + return 0, errors.New("invalid ListAt destination, zero size") + } + if len(f) <= len(l.info) { + files := make([]os.FileInfo, 0, len(f)) + for idx := len(l.info) - 1; idx >= 0; idx-- { + files = append(files, l.info[idx]) + if len(files) == len(f) { + l.info = l.info[:idx] + n := copy(f, files) + return n, nil + } + } + } + limit := len(f) - len(l.info) + files, err := l.Next(limit) + n := copy(f, files) + return n, err +} + +// Next reads the directory and returns a slice of up to n FileInfo values. +func (l *DirListerAt) Next(limit int) ([]os.FileInfo, error) { + for { + files, err := l.lister.Next(limit) + if err != nil && !errors.Is(err, io.EOF) { + return files, err + } + files = l.user.FilterListDir(files, l.virtualPath) + if len(l.info) > 0 { + for _, fi := range l.info { + files = util.PrependFileInfo(files, fi) + } + l.info = nil + } + if err != nil || len(files) > 0 { + return files, err + } + } +} + +// Close closes the DirListerAt +func (l *DirListerAt) Close() error { + l.mu.Lock() + defer l.mu.Unlock() + + return l.lister.Close() +} + +func (l *DirListerAt) convertError(err error) error { + if errors.Is(err, io.EOF) { + return nil + } + return err +} + func getPermissionDeniedError(protocol string) error { switch protocol { case ProtocolSFTP: diff --git a/internal/common/connection_test.go b/internal/common/connection_test.go index bee6d7fa..d7e1b904 100644 --- a/internal/common/connection_test.go +++ b/internal/common/connection_test.go @@ -17,10 +17,12 @@ package common import ( "errors" "fmt" + "io" "os" "path" "path/filepath" "runtime" + "strconv" "testing" "time" @@ -601,8 +603,10 @@ func TestErrorResolvePath(t *testing.T) { } conn := NewBaseConnection("", ProtocolSFTP, "", "", u) - err := conn.doRecursiveRemoveDirEntry("/vpath", nil) + err := conn.doRecursiveRemoveDirEntry("/vpath", nil, 0) assert.Error(t, err) + err = conn.doRecursiveRemove(nil, "/fspath", "/vpath", vfs.NewFileInfo("vpath", true, 0, time.Now(), false), 2000) + assert.Error(t, err, util.ErrRecursionTooDeep) err = conn.checkCopy(vfs.NewFileInfo("name", true, 0, time.Unix(0, 0), false), nil, "/source", "/target") assert.Error(t, err) sourceFile := filepath.Join(os.TempDir(), "f", "source") @@ -700,26 +704,32 @@ func TestFilePatterns(t *testing.T) { VirtualFolders: virtualFolders, } + getFilteredInfo := func(dirContents []os.FileInfo, virtualPath string) []os.FileInfo { + result := user.FilterListDir(dirContents, virtualPath) + result = append(result, user.GetVirtualFoldersInfo(virtualPath)...) + return result + } + dirContents := []os.FileInfo{ vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), } // dirContents are modified in place, we need to redefine them each time - filtered := user.FilterListDir(dirContents, "/dir1") + filtered := getFilteredInfo(dirContents, "/dir1") assert.Len(t, filtered, 5) dirContents = []os.FileInfo{ vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), } - filtered = user.FilterListDir(dirContents, "/dir1/vdir1") + filtered = getFilteredInfo(dirContents, "/dir1/vdir1") assert.Len(t, filtered, 2) dirContents = []os.FileInfo{ vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), } - filtered = user.FilterListDir(dirContents, "/dir2/vdir2") + filtered = getFilteredInfo(dirContents, "/dir2/vdir2") require.Len(t, filtered, 1) assert.Equal(t, "file1.jpg", filtered[0].Name()) @@ -727,7 +737,7 @@ func TestFilePatterns(t *testing.T) { vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), } - filtered = user.FilterListDir(dirContents, "/dir2/vdir2/sub") + filtered = getFilteredInfo(dirContents, "/dir2/vdir2/sub") require.Len(t, filtered, 1) assert.Equal(t, "file1.jpg", filtered[0].Name()) @@ -754,14 +764,14 @@ func TestFilePatterns(t *testing.T) { vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), } - filtered = user.FilterListDir(dirContents, "/dir4") + filtered = getFilteredInfo(dirContents, "/dir4") require.Len(t, filtered, 0) dirContents = []os.FileInfo{ vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), } - filtered = user.FilterListDir(dirContents, "/dir4/vdir2/sub") + filtered = getFilteredInfo(dirContents, "/dir4/vdir2/sub") require.Len(t, filtered, 0) dirContents = []os.FileInfo{ @@ -769,7 +779,7 @@ func TestFilePatterns(t *testing.T) { vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), } - filtered = user.FilterListDir(dirContents, "/dir2") + filtered = getFilteredInfo(dirContents, "/dir2") assert.Len(t, filtered, 2) dirContents = []os.FileInfo{ @@ -777,7 +787,7 @@ func TestFilePatterns(t *testing.T) { vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), } - filtered = user.FilterListDir(dirContents, "/dir4") + filtered = getFilteredInfo(dirContents, "/dir4") assert.Len(t, filtered, 0) dirContents = []os.FileInfo{ @@ -785,7 +795,7 @@ func TestFilePatterns(t *testing.T) { vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), } - filtered = user.FilterListDir(dirContents, "/dir4/sub") + filtered = getFilteredInfo(dirContents, "/dir4/sub") assert.Len(t, filtered, 0) dirContents = []os.FileInfo{ @@ -793,10 +803,10 @@ func TestFilePatterns(t *testing.T) { vfs.NewFileInfo("vdir3.jpg", false, 123, time.Now(), false), } - filtered = user.FilterListDir(dirContents, "/dir1") + filtered = getFilteredInfo(dirContents, "/dir1") assert.Len(t, filtered, 5) - filtered = user.FilterListDir(dirContents, "/dir2") + filtered = getFilteredInfo(dirContents, "/dir2") if assert.Len(t, filtered, 1) { assert.True(t, filtered[0].IsDir()) } @@ -806,14 +816,14 @@ func TestFilePatterns(t *testing.T) { vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("vdir3.jpg", false, 123, time.Now(), false), } - filtered = user.FilterListDir(dirContents, "/dir1") + filtered = getFilteredInfo(dirContents, "/dir1") assert.Len(t, filtered, 2) dirContents = []os.FileInfo{ vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("vdir3.jpg", false, 123, time.Now(), false), } - filtered = user.FilterListDir(dirContents, "/dir2") + filtered = getFilteredInfo(dirContents, "/dir2") if assert.Len(t, filtered, 1) { assert.False(t, filtered[0].IsDir()) } @@ -824,7 +834,7 @@ func TestFilePatterns(t *testing.T) { vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false), vfs.NewFileInfo("vdir3.jpg", false, 123, time.Now(), false), } - filtered = user.FilterListDir(dirContents, "/dir2") + filtered = getFilteredInfo(dirContents, "/dir2") if assert.Len(t, filtered, 2) { assert.False(t, filtered[0].IsDir()) assert.False(t, filtered[1].IsDir()) @@ -832,9 +842,9 @@ func TestFilePatterns(t *testing.T) { user.VirtualFolders = virtualFolders user.Filters = filters - filtered = user.FilterListDir(nil, "/dir1") + filtered = getFilteredInfo(nil, "/dir1") assert.Len(t, filtered, 3) - filtered = user.FilterListDir(nil, "/dir2") + filtered = getFilteredInfo(nil, "/dir2") assert.Len(t, filtered, 1) dirContents = []os.FileInfo{ @@ -843,7 +853,7 @@ func TestFilePatterns(t *testing.T) { vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false), vfs.NewFileInfo("vdir3.jpg", false, 456, time.Now(), false), } - filtered = user.FilterListDir(dirContents, "/dir2") + filtered = getFilteredInfo(dirContents, "/dir2") assert.Len(t, filtered, 2) user = dataprovider.User{ @@ -866,7 +876,7 @@ func TestFilePatterns(t *testing.T) { vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false), vfs.NewFileInfo("vdir3.jpg", false, 456, time.Now(), false), } - filtered = user.FilterListDir(dirContents, "/dir3") + filtered = getFilteredInfo(dirContents, "/dir3") assert.Len(t, filtered, 0) dirContents = nil @@ -881,7 +891,7 @@ func TestFilePatterns(t *testing.T) { dirContents = append(dirContents, vfs.NewFileInfo("ic35.*", false, 123, time.Now(), false)) dirContents = append(dirContents, vfs.NewFileInfo("file.jpg", false, 123, time.Now(), false)) - filtered = user.FilterListDir(dirContents, "/dir3") + filtered = getFilteredInfo(dirContents, "/dir3") require.Len(t, filtered, 1) assert.Equal(t, "ic35", filtered[0].Name()) @@ -890,7 +900,7 @@ func TestFilePatterns(t *testing.T) { vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false), } - filtered = user.FilterListDir(dirContents, "/dir3/ic36") + filtered = getFilteredInfo(dirContents, "/dir3/ic36") require.Len(t, filtered, 0) dirContents = []os.FileInfo{ @@ -898,7 +908,7 @@ func TestFilePatterns(t *testing.T) { vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false), } - filtered = user.FilterListDir(dirContents, "/dir3/ic35") + filtered = getFilteredInfo(dirContents, "/dir3/ic35") require.Len(t, filtered, 3) dirContents = []os.FileInfo{ @@ -906,7 +916,7 @@ func TestFilePatterns(t *testing.T) { vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false), } - filtered = user.FilterListDir(dirContents, "/dir3/ic35/sub") + filtered = getFilteredInfo(dirContents, "/dir3/ic35/sub") require.Len(t, filtered, 3) res, _ = user.IsFileAllowed("/dir3/file.txt") @@ -930,7 +940,7 @@ func TestFilePatterns(t *testing.T) { vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false), } - filtered = user.FilterListDir(dirContents, "/dir3/ic35/sub") + filtered = getFilteredInfo(dirContents, "/dir3/ic35/sub") require.Len(t, filtered, 3) user.Filters.FilePatterns = append(user.Filters.FilePatterns, sdk.PatternsFilter{ @@ -949,7 +959,7 @@ func TestFilePatterns(t *testing.T) { vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false), } - filtered = user.FilterListDir(dirContents, "/dir3/ic35/sub1") + filtered = getFilteredInfo(dirContents, "/dir3/ic35/sub1") require.Len(t, filtered, 3) dirContents = []os.FileInfo{ @@ -957,7 +967,7 @@ func TestFilePatterns(t *testing.T) { vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false), } - filtered = user.FilterListDir(dirContents, "/dir3/ic35/sub2") + filtered = getFilteredInfo(dirContents, "/dir3/ic35/sub2") require.Len(t, filtered, 2) dirContents = []os.FileInfo{ @@ -965,7 +975,7 @@ func TestFilePatterns(t *testing.T) { vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false), } - filtered = user.FilterListDir(dirContents, "/dir3/ic35/sub2/sub1") + filtered = getFilteredInfo(dirContents, "/dir3/ic35/sub2/sub1") require.Len(t, filtered, 2) res, _ = user.IsFileAllowed("/dir3/ic35/file.jpg") @@ -1023,7 +1033,7 @@ func TestFilePatterns(t *testing.T) { vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false), } - filtered = user.FilterListDir(dirContents, "/dir3/ic35") + filtered = getFilteredInfo(dirContents, "/dir3/ic35") require.Len(t, filtered, 1) dirContents = []os.FileInfo{ @@ -1031,6 +1041,116 @@ func TestFilePatterns(t *testing.T) { vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false), } - filtered = user.FilterListDir(dirContents, "/dir3/ic35/abc") + filtered = getFilteredInfo(dirContents, "/dir3/ic35/abc") require.Len(t, filtered, 1) } + +func TestListerAt(t *testing.T) { + dir := t.TempDir() + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: "u", + Password: "p", + HomeDir: dir, + Status: 1, + Permissions: map[string][]string{ + "/": {"*"}, + }, + }, + } + conn := NewBaseConnection(xid.New().String(), ProtocolSFTP, "", "", user) + lister, err := conn.ListDir("/") + require.NoError(t, err) + files, err := lister.Next(1) + require.ErrorIs(t, err, io.EOF) + require.Len(t, files, 0) + err = lister.Close() + require.NoError(t, err) + + conn.User.VirtualFolders = []vfs.VirtualFolder{ + { + VirtualPath: "p1", + }, + { + VirtualPath: "p2", + }, + { + VirtualPath: "p3", + }, + } + lister, err = conn.ListDir("/") + require.NoError(t, err) + files, err = lister.Next(2) + // virtual directories exceeds the limit + require.ErrorIs(t, err, io.EOF) + require.Len(t, files, 3) + files, err = lister.Next(2) + require.ErrorIs(t, err, io.EOF) + require.Len(t, files, 0) + _, err = lister.Next(-1) + require.ErrorContains(t, err, "invalid limit") + err = lister.Close() + require.NoError(t, err) + + lister, err = conn.ListDir("/") + require.NoError(t, err) + _, err = lister.ListAt(nil, 0) + require.ErrorContains(t, err, "zero size") + err = lister.Close() + require.NoError(t, err) + + for i := 0; i < 100; i++ { + f, err := os.Create(filepath.Join(dir, strconv.Itoa(i))) + require.NoError(t, err) + err = f.Close() + require.NoError(t, err) + } + lister, err = conn.ListDir("/") + require.NoError(t, err) + files = make([]os.FileInfo, 18) + n, err := lister.ListAt(files, 0) + require.NoError(t, err) + require.Equal(t, 18, n) + n, err = lister.ListAt(files, 0) + require.NoError(t, err) + require.Equal(t, 18, n) + files = make([]os.FileInfo, 100) + n, err = lister.ListAt(files, 0) + require.NoError(t, err) + require.Equal(t, 64+3, n) + n, err = lister.ListAt(files, 0) + require.ErrorIs(t, err, io.EOF) + require.Equal(t, 0, n) + n, err = lister.ListAt(files, 0) + require.ErrorIs(t, err, io.EOF) + require.Equal(t, 0, n) + err = lister.Close() + require.NoError(t, err) + n, err = lister.ListAt(files, 0) + require.Error(t, err) + require.NotErrorIs(t, err, io.EOF) + require.Equal(t, 0, n) + lister, err = conn.ListDir("/") + require.NoError(t, err) + lister.Add(vfs.NewFileInfo("..", true, 0, time.Unix(0, 0), false)) + lister.Add(vfs.NewFileInfo(".", true, 0, time.Unix(0, 0), false)) + files = make([]os.FileInfo, 1) + n, err = lister.ListAt(files, 0) + require.NoError(t, err) + require.Equal(t, 1, n) + assert.Equal(t, ".", files[0].Name()) + files = make([]os.FileInfo, 2) + n, err = lister.ListAt(files, 0) + require.NoError(t, err) + require.Equal(t, 2, n) + assert.Equal(t, "..", files[0].Name()) + assert.Equal(t, "p3", files[1].Name()) + files = make([]os.FileInfo, 200) + n, err = lister.ListAt(files, 0) + require.NoError(t, err) + require.Equal(t, 102, n) + assert.Equal(t, "p2", files[0].Name()) + assert.Equal(t, "p1", files[1].Name()) + err = lister.Close() + require.NoError(t, err) +} diff --git a/internal/common/dataretention.go b/internal/common/dataretention.go index e5e95d13..f2f6c21a 100644 --- a/internal/common/dataretention.go +++ b/internal/common/dataretention.go @@ -18,7 +18,9 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" + "io" "net/http" "net/url" "os" @@ -37,6 +39,7 @@ import ( "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/smtp" "github.com/drakkan/sftpgo/v2/internal/util" + "github.com/drakkan/sftpgo/v2/internal/vfs" ) // RetentionCheckNotification defines the supported notification methods for a retention check result @@ -226,8 +229,17 @@ func (c *RetentionCheck) removeFile(virtualPath string, info os.FileInfo) error return c.conn.RemoveFile(fs, fsPath, virtualPath, info) } -func (c *RetentionCheck) cleanupFolder(folderPath string) error { - deleteFilesPerms := []string{dataprovider.PermDelete, dataprovider.PermDeleteFiles} +func (c *RetentionCheck) hasCleanupPerms(folderPath string) bool { + if !c.conn.User.HasPerm(dataprovider.PermListItems, folderPath) { + return false + } + if !c.conn.User.HasAnyPerm([]string{dataprovider.PermDelete, dataprovider.PermDeleteFiles}, folderPath) { + return false + } + return true +} + +func (c *RetentionCheck) cleanupFolder(folderPath string, recursion int) error { startTime := time.Now() result := folderRetentionCheckResult{ Path: folderPath, @@ -235,7 +247,15 @@ func (c *RetentionCheck) cleanupFolder(folderPath string) error { defer func() { c.results = append(c.results, result) }() - if !c.conn.User.HasPerm(dataprovider.PermListItems, folderPath) || !c.conn.User.HasAnyPerm(deleteFilesPerms, folderPath) { + if recursion >= util.MaxRecursion { + result.Elapsed = time.Since(startTime) + result.Info = "data retention check skipped: recursion too deep" + c.conn.Log(logger.LevelError, "data retention check skipped, recursion too depth for %q: %d", + folderPath, recursion) + return util.ErrRecursionTooDeep + } + recursion++ + if !c.hasCleanupPerms(folderPath) { result.Elapsed = time.Since(startTime) result.Info = "data retention check skipped: no permissions" c.conn.Log(logger.LevelInfo, "user %q does not have permissions to check retention on %q, retention check skipped", @@ -259,7 +279,7 @@ func (c *RetentionCheck) cleanupFolder(folderPath string) error { } c.conn.Log(logger.LevelDebug, "start retention check for folder %q, retention: %v hours, delete empty dirs? %v, ignore user perms? %v", folderPath, folderRetention.Retention, folderRetention.DeleteEmptyDirs, folderRetention.IgnoreUserPermissions) - files, err := c.conn.ListDir(folderPath) + lister, err := c.conn.ListDir(folderPath) if err != nil { result.Elapsed = time.Since(startTime) if err == c.conn.GetNotExistError() { @@ -267,40 +287,54 @@ func (c *RetentionCheck) cleanupFolder(folderPath string) error { c.conn.Log(logger.LevelDebug, "folder %q does not exist, retention check skipped", folderPath) return nil } - result.Error = fmt.Sprintf("unable to list directory %q", folderPath) + result.Error = fmt.Sprintf("unable to get lister for directory %q", folderPath) c.conn.Log(logger.LevelError, result.Error) return err } - for _, info := range files { - virtualPath := path.Join(folderPath, info.Name()) - if info.IsDir() { - if err := c.cleanupFolder(virtualPath); err != nil { - result.Elapsed = time.Since(startTime) - result.Error = fmt.Sprintf("unable to check folder: %v", err) - c.conn.Log(logger.LevelError, "unable to cleanup folder %q: %v", virtualPath, err) - return err - } - } else { - retentionTime := info.ModTime().Add(time.Duration(folderRetention.Retention) * time.Hour) - if retentionTime.Before(time.Now()) { - if err := c.removeFile(virtualPath, info); err != nil { + defer lister.Close() + + for { + files, err := lister.Next(vfs.ListerBatchSize) + finished := errors.Is(err, io.EOF) + if err := lister.convertError(err); err != nil { + result.Elapsed = time.Since(startTime) + result.Error = fmt.Sprintf("unable to list directory %q", folderPath) + c.conn.Log(logger.LevelError, "unable to list dir %q: %v", folderPath, err) + return err + } + for _, info := range files { + virtualPath := path.Join(folderPath, info.Name()) + if info.IsDir() { + if err := c.cleanupFolder(virtualPath, recursion); err != nil { result.Elapsed = time.Since(startTime) - result.Error = fmt.Sprintf("unable to remove file %q: %v", virtualPath, err) - c.conn.Log(logger.LevelError, "unable to remove file %q, retention %v: %v", - virtualPath, retentionTime, err) + result.Error = fmt.Sprintf("unable to check folder: %v", err) + c.conn.Log(logger.LevelError, "unable to cleanup folder %q: %v", virtualPath, err) return err } - c.conn.Log(logger.LevelDebug, "removed file %q, modification time: %v, retention: %v hours, retention time: %v", - virtualPath, info.ModTime(), folderRetention.Retention, retentionTime) - result.DeletedFiles++ - result.DeletedSize += info.Size() + } else { + retentionTime := info.ModTime().Add(time.Duration(folderRetention.Retention) * time.Hour) + if retentionTime.Before(time.Now()) { + if err := c.removeFile(virtualPath, info); err != nil { + result.Elapsed = time.Since(startTime) + result.Error = fmt.Sprintf("unable to remove file %q: %v", virtualPath, err) + c.conn.Log(logger.LevelError, "unable to remove file %q, retention %v: %v", + virtualPath, retentionTime, err) + return err + } + c.conn.Log(logger.LevelDebug, "removed file %q, modification time: %v, retention: %v hours, retention time: %v", + virtualPath, info.ModTime(), folderRetention.Retention, retentionTime) + result.DeletedFiles++ + result.DeletedSize += info.Size() + } } } + if finished { + break + } } - if folderRetention.DeleteEmptyDirs { - c.checkEmptyDirRemoval(folderPath) - } + lister.Close() + c.checkEmptyDirRemoval(folderPath, folderRetention.DeleteEmptyDirs) result.Elapsed = time.Since(startTime) c.conn.Log(logger.LevelDebug, "retention check completed for folder %q, deleted files: %v, deleted size: %v bytes", folderPath, result.DeletedFiles, result.DeletedSize) @@ -308,8 +342,8 @@ func (c *RetentionCheck) cleanupFolder(folderPath string) error { return nil } -func (c *RetentionCheck) checkEmptyDirRemoval(folderPath string) { - if folderPath == "/" { +func (c *RetentionCheck) checkEmptyDirRemoval(folderPath string, checkVal bool) { + if folderPath == "/" || !checkVal { return } for _, folder := range c.Folders { @@ -322,10 +356,14 @@ func (c *RetentionCheck) checkEmptyDirRemoval(folderPath string) { dataprovider.PermDeleteDirs, }, path.Dir(folderPath), ) { - files, err := c.conn.ListDir(folderPath) - if err == nil && len(files) == 0 { - err = c.conn.RemoveDir(folderPath) - c.conn.Log(logger.LevelDebug, "tried to remove empty dir %q, error: %v", folderPath, err) + lister, err := c.conn.ListDir(folderPath) + if err == nil { + files, err := lister.Next(1) + lister.Close() + if len(files) == 0 && errors.Is(err, io.EOF) { + err = c.conn.RemoveDir(folderPath) + c.conn.Log(logger.LevelDebug, "tried to remove empty dir %q, error: %v", folderPath, err) + } } } } @@ -339,7 +377,7 @@ func (c *RetentionCheck) Start() error { startTime := time.Now() for _, folder := range c.Folders { if folder.Retention > 0 { - if err := c.cleanupFolder(folder.Path); err != nil { + if err := c.cleanupFolder(folder.Path, 0); err != nil { c.conn.Log(logger.LevelError, "retention check failed, unable to cleanup folder %q", folder.Path) c.sendNotifications(time.Since(startTime), err) return err diff --git a/internal/common/dataretention_test.go b/internal/common/dataretention_test.go index 703f8c04..adc59314 100644 --- a/internal/common/dataretention_test.go +++ b/internal/common/dataretention_test.go @@ -28,6 +28,7 @@ import ( "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/smtp" + "github.com/drakkan/sftpgo/v2/internal/util" ) func TestRetentionValidation(t *testing.T) { @@ -272,7 +273,9 @@ func TestRetentionPermissionsAndGetFolder(t *testing.T) { conn.SetProtocol(ProtocolDataRetention) conn.ID = fmt.Sprintf("data_retention_%v", user.Username) check.conn = conn + assert.False(t, check.hasCleanupPerms(check.Folders[2].Path)) check.updateUserPermissions() + assert.True(t, check.hasCleanupPerms(check.Folders[2].Path)) assert.Equal(t, []string{dataprovider.PermListItems, dataprovider.PermDelete}, conn.User.Permissions["/"]) assert.Equal(t, []string{dataprovider.PermListItems}, conn.User.Permissions["/dir1"]) assert.Equal(t, []string{dataprovider.PermAny}, conn.User.Permissions["/dir2"]) @@ -391,8 +394,11 @@ func TestCleanupErrors(t *testing.T) { err := check.removeFile("missing file", nil) assert.Error(t, err) - err = check.cleanupFolder("/") + err = check.cleanupFolder("/", 0) assert.Error(t, err) + err = check.cleanupFolder("/", 1000) + assert.ErrorIs(t, err, util.ErrRecursionTooDeep) + assert.True(t, RetentionChecks.remove(user.Username)) } diff --git a/internal/common/eventmanager.go b/internal/common/eventmanager.go index e1e5cc8e..be4695d6 100644 --- a/internal/common/eventmanager.go +++ b/internal/common/eventmanager.go @@ -988,11 +988,16 @@ func getFileWriter(conn *BaseConnection, virtualPath string, expectedSize int64) return w, numFiles, truncatedSize, cancelFn, nil } -func addZipEntry(wr *zipWriterWrapper, conn *BaseConnection, entryPath, baseDir string) error { +func addZipEntry(wr *zipWriterWrapper, conn *BaseConnection, entryPath, baseDir string, recursion int) error { if entryPath == wr.Name { // skip the archive itself return nil } + if recursion >= util.MaxRecursion { + eventManagerLog(logger.LevelError, "unable to add zip entry %q, recursion too deep: %v", entryPath, recursion) + return util.ErrRecursionTooDeep + } + recursion++ info, err := conn.DoStat(entryPath, 1, false) if err != nil { eventManagerLog(logger.LevelError, "unable to add zip entry %q, stat error: %v", entryPath, err) @@ -1018,25 +1023,42 @@ func addZipEntry(wr *zipWriterWrapper, conn *BaseConnection, entryPath, baseDir eventManagerLog(logger.LevelError, "unable to create zip entry %q: %v", entryPath, err) return fmt.Errorf("unable to create zip entry %q: %w", entryPath, err) } - contents, err := conn.ListDir(entryPath) + lister, err := conn.ListDir(entryPath) if err != nil { - eventManagerLog(logger.LevelError, "unable to add zip entry %q, read dir error: %v", entryPath, err) + eventManagerLog(logger.LevelError, "unable to add zip entry %q, get dir lister error: %v", entryPath, err) return fmt.Errorf("unable to add zip entry %q: %w", entryPath, err) } - for _, info := range contents { - fullPath := util.CleanPath(path.Join(entryPath, info.Name())) - if err := addZipEntry(wr, conn, fullPath, baseDir); err != nil { - eventManagerLog(logger.LevelError, "unable to add zip entry: %v", err) - return err + defer lister.Close() + + for { + contents, err := lister.Next(vfs.ListerBatchSize) + finished := errors.Is(err, io.EOF) + if err := lister.convertError(err); err != nil { + eventManagerLog(logger.LevelError, "unable to add zip entry %q, read dir error: %v", entryPath, err) + return fmt.Errorf("unable to add zip entry %q: %w", entryPath, err) + } + for _, info := range contents { + fullPath := util.CleanPath(path.Join(entryPath, info.Name())) + if err := addZipEntry(wr, conn, fullPath, baseDir, recursion); err != nil { + eventManagerLog(logger.LevelError, "unable to add zip entry: %v", err) + return err + } + } + if finished { + return nil } } - return nil } if !info.Mode().IsRegular() { // we only allow regular files eventManagerLog(logger.LevelInfo, "skipping zip entry for non regular file %q", entryPath) return nil } + + return addFileToZip(wr, conn, entryPath, entryName, info.ModTime()) +} + +func addFileToZip(wr *zipWriterWrapper, conn *BaseConnection, entryPath, entryName string, modTime time.Time) error { reader, cancelFn, err := getFileReader(conn, entryPath) if err != nil { eventManagerLog(logger.LevelError, "unable to add zip entry %q, cannot open file: %v", entryPath, err) @@ -1048,7 +1070,7 @@ func addZipEntry(wr *zipWriterWrapper, conn *BaseConnection, entryPath, baseDir f, err := wr.Writer.CreateHeader(&zip.FileHeader{ Name: entryName, Method: zip.Deflate, - Modified: info.ModTime(), + Modified: modTime, }) if err != nil { eventManagerLog(logger.LevelError, "unable to create zip entry %q: %v", entryPath, err) @@ -1890,18 +1912,28 @@ func getArchiveBaseDir(paths []string) string { func getSizeForPath(conn *BaseConnection, p string, info os.FileInfo) (int64, error) { if info.IsDir() { var dirSize int64 - entries, err := conn.ListDir(p) + lister, err := conn.ListDir(p) if err != nil { return 0, err } - for _, entry := range entries { - size, err := getSizeForPath(conn, path.Join(p, entry.Name()), entry) - if err != nil { + defer lister.Close() + for { + entries, err := lister.Next(vfs.ListerBatchSize) + finished := errors.Is(err, io.EOF) + if err != nil && !finished { return 0, err } - dirSize += size + for _, entry := range entries { + size, err := getSizeForPath(conn, path.Join(p, entry.Name()), entry) + if err != nil { + return 0, err + } + dirSize += size + } + if finished { + return dirSize, nil + } } - return dirSize, nil } if info.Mode().IsRegular() { return info.Size(), nil @@ -1978,7 +2010,7 @@ func executeCompressFsActionForUser(c dataprovider.EventActionFsCompress, replac } startTime := time.Now() for _, item := range paths { - if err := addZipEntry(zipWriter, conn, item, baseDir); err != nil { + if err := addZipEntry(zipWriter, conn, item, baseDir, 0); err != nil { closeWriterAndUpdateQuota(writer, conn, name, "", numFiles, truncatedSize, err, operationUpload, startTime) //nolint:errcheck return err } diff --git a/internal/common/eventmanager_test.go b/internal/common/eventmanager_test.go index dea66671..2b5c7863 100644 --- a/internal/common/eventmanager_test.go +++ b/internal/common/eventmanager_test.go @@ -1835,7 +1835,7 @@ func TestFilesystemActionErrors(t *testing.T) { Writer: zip.NewWriter(bytes.NewBuffer(nil)), Entries: map[string]bool{}, } - err = addZipEntry(wr, conn, "/adir/sub/f.dat", "/adir/sub/sub") + err = addZipEntry(wr, conn, "/adir/sub/f.dat", "/adir/sub/sub", 0) assert.Error(t, err) assert.Contains(t, getErrorString(err), "is outside base dir") } diff --git a/internal/dataprovider/sqlqueries.go b/internal/dataprovider/sqlqueries.go index 0805216e..2ce95b5c 100644 --- a/internal/dataprovider/sqlqueries.go +++ b/internal/dataprovider/sqlqueries.go @@ -131,9 +131,9 @@ func getDefenderHostQuery() string { sqlTableDefenderHosts, sqlPlaceholders[0], sqlPlaceholders[1]) } -func getDefenderEventsQuery(hostIDS []int64) string { +func getDefenderEventsQuery(hostIDs []int64) string { var sb strings.Builder - for _, hID := range hostIDS { + for _, hID := range hostIDs { if sb.Len() == 0 { sb.WriteString("(") } else { diff --git a/internal/dataprovider/user.go b/internal/dataprovider/user.go index d9b1663d..a604876a 100644 --- a/internal/dataprovider/user.go +++ b/internal/dataprovider/user.go @@ -676,8 +676,7 @@ func (u *User) GetVirtualFoldersInPath(virtualPath string) map[string]bool { result := make(map[string]bool) for idx := range u.VirtualFolders { - v := &u.VirtualFolders[idx] - dirsForPath := util.GetDirsForVirtualPath(v.VirtualPath) + dirsForPath := util.GetDirsForVirtualPath(u.VirtualFolders[idx].VirtualPath) for index := range dirsForPath { d := dirsForPath[index] if d == "/" { @@ -716,13 +715,34 @@ func (u *User) hasVirtualDirs() bool { return numFolders > 0 } -// FilterListDir adds virtual folders and remove hidden items from the given files list +// GetVirtualFoldersInfo returns []os.FileInfo for virtual folders +func (u *User) GetVirtualFoldersInfo(virtualPath string) []os.FileInfo { + filter := u.getPatternsFilterForPath(virtualPath) + if !u.hasVirtualDirs() && filter.DenyPolicy != sdk.DenyPolicyHide { + return nil + } + vdirs := u.GetVirtualFoldersInPath(virtualPath) + result := make([]os.FileInfo, 0, len(vdirs)) + + for dir := range u.GetVirtualFoldersInPath(virtualPath) { + dirName := path.Base(dir) + if filter.DenyPolicy == sdk.DenyPolicyHide { + if !filter.CheckAllowed(dirName) { + continue + } + } + result = append(result, vfs.NewFileInfo(dirName, true, 0, time.Unix(0, 0), false)) + } + + return result +} + +// FilterListDir removes hidden items from the given files list func (u *User) FilterListDir(dirContents []os.FileInfo, virtualPath string) []os.FileInfo { filter := u.getPatternsFilterForPath(virtualPath) if !u.hasVirtualDirs() && filter.DenyPolicy != sdk.DenyPolicyHide { return dirContents } - vdirs := make(map[string]bool) for dir := range u.GetVirtualFoldersInPath(virtualPath) { dirName := path.Base(dir) @@ -735,36 +755,24 @@ func (u *User) FilterListDir(dirContents []os.FileInfo, virtualPath string) []os } validIdx := 0 - for index, fi := range dirContents { - for dir := range vdirs { - if fi.Name() == dir { - if !fi.IsDir() { - fi = vfs.NewFileInfo(dir, true, 0, time.Unix(0, 0), false) - dirContents[index] = fi + for idx := range dirContents { + fi := dirContents[idx] + + if fi.Name() != "." && fi.Name() != ".." { + if _, ok := vdirs[fi.Name()]; ok { + continue + } + if filter.DenyPolicy == sdk.DenyPolicyHide { + if !filter.CheckAllowed(fi.Name()) { + continue } - delete(vdirs, dir) - } - } - if filter.DenyPolicy == sdk.DenyPolicyHide { - if filter.CheckAllowed(fi.Name()) { - dirContents[validIdx] = fi - validIdx++ } } + dirContents[validIdx] = fi + validIdx++ } - if filter.DenyPolicy == sdk.DenyPolicyHide { - for idx := validIdx; idx < len(dirContents); idx++ { - dirContents[idx] = nil - } - dirContents = dirContents[:validIdx] - } - - for dir := range vdirs { - fi := vfs.NewFileInfo(dir, true, 0, time.Unix(0, 0), false) - dirContents = append(dirContents, fi) - } - return dirContents + return dirContents[:validIdx] } // IsMappedPath returns true if the specified filesystem path has a virtual folder mapping. diff --git a/internal/ftpd/ftpd_test.go b/internal/ftpd/ftpd_test.go index d1a1414b..f4bf0f71 100644 --- a/internal/ftpd/ftpd_test.go +++ b/internal/ftpd/ftpd_test.go @@ -2544,14 +2544,14 @@ func TestRename(t *testing.T) { assert.NoError(t, err) err = client.MakeDir(path.Join(otherDir, testDir)) assert.NoError(t, err) - code, response, err := client.SendCustomCommand(fmt.Sprintf("SITE CHMOD 0001 %v", otherDir)) + code, response, err := client.SendCommand(fmt.Sprintf("SITE CHMOD 0001 %v", otherDir)) assert.NoError(t, err) assert.Equal(t, ftp.StatusCommandOK, code) assert.Equal(t, "SITE CHMOD command successful", response) err = client.Rename(testDir, path.Join(otherDir, testDir)) assert.Error(t, err) - code, response, err = client.SendCustomCommand(fmt.Sprintf("SITE CHMOD 755 %v", otherDir)) + code, response, err = client.SendCommand(fmt.Sprintf("SITE CHMOD 755 %v", otherDir)) assert.NoError(t, err) assert.Equal(t, ftp.StatusCommandOK, code) assert.Equal(t, "SITE CHMOD command successful", response) @@ -2611,7 +2611,7 @@ func TestSymlink(t *testing.T) { assert.NoError(t, err) err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) assert.NoError(t, err) - code, _, err := client.SendCustomCommand(fmt.Sprintf("SITE SYMLINK %v %v", testFileName, testFileName+".link")) + code, _, err := client.SendCommand(fmt.Sprintf("SITE SYMLINK %v %v", testFileName, testFileName+".link")) assert.NoError(t, err) assert.Equal(t, ftp.StatusCommandOK, code) @@ -2622,15 +2622,15 @@ func TestSymlink(t *testing.T) { assert.NoError(t, err) err = client.MakeDir(path.Join(otherDir, testDir)) assert.NoError(t, err) - code, response, err := client.SendCustomCommand(fmt.Sprintf("SITE CHMOD 0001 %v", otherDir)) + code, response, err := client.SendCommand(fmt.Sprintf("SITE CHMOD 0001 %v", otherDir)) assert.NoError(t, err) assert.Equal(t, ftp.StatusCommandOK, code) assert.Equal(t, "SITE CHMOD command successful", response) - code, _, err = client.SendCustomCommand(fmt.Sprintf("SITE SYMLINK %v %v", testDir, path.Join(otherDir, testDir))) + code, _, err = client.SendCommand(fmt.Sprintf("SITE SYMLINK %v %v", testDir, path.Join(otherDir, testDir))) assert.NoError(t, err) assert.Equal(t, ftp.StatusFileUnavailable, code) - code, response, err = client.SendCustomCommand(fmt.Sprintf("SITE CHMOD 755 %v", otherDir)) + code, response, err = client.SendCommand(fmt.Sprintf("SITE CHMOD 755 %v", otherDir)) assert.NoError(t, err) assert.Equal(t, ftp.StatusCommandOK, code) assert.Equal(t, "SITE CHMOD command successful", response) @@ -2860,17 +2860,17 @@ func TestAllocateAvailable(t *testing.T) { assert.NoError(t, err) client, err := getFTPClient(user, false, nil) if assert.NoError(t, err) { - code, response, err := client.SendCustomCommand("allo 2000000") + code, response, err := client.SendCommand("allo 2000000") assert.NoError(t, err) assert.Equal(t, ftp.StatusCommandOK, code) assert.Equal(t, "Done !", response) - code, response, err = client.SendCustomCommand("AVBL /vdir") + code, response, err = client.SendCommand("AVBL /vdir") assert.NoError(t, err) assert.Equal(t, ftp.StatusFile, code) assert.Equal(t, "110", response) - code, _, err = client.SendCustomCommand("AVBL") + code, _, err = client.SendCommand("AVBL") assert.NoError(t, err) assert.Equal(t, ftp.StatusFile, code) @@ -2886,7 +2886,7 @@ func TestAllocateAvailable(t *testing.T) { testFileSize := user.QuotaSize - 1 err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) - code, response, err := client.SendCustomCommand("allo 1000") + code, response, err := client.SendCommand("allo 1000") assert.NoError(t, err) assert.Equal(t, ftp.StatusCommandOK, code) assert.Equal(t, "Done !", response) @@ -2894,7 +2894,7 @@ func TestAllocateAvailable(t *testing.T) { err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) assert.NoError(t, err) - code, response, err = client.SendCustomCommand("AVBL") + code, response, err = client.SendCommand("AVBL") assert.NoError(t, err) assert.Equal(t, ftp.StatusFile, code) assert.Equal(t, "1", response) @@ -2909,7 +2909,7 @@ func TestAllocateAvailable(t *testing.T) { assert.NoError(t, err) client, err = getFTPClient(user, false, nil) if assert.NoError(t, err) { - code, response, err := client.SendCustomCommand("AVBL") + code, response, err := client.SendCommand("AVBL") assert.NoError(t, err) assert.Equal(t, ftp.StatusFile, code) assert.Equal(t, "1", response) @@ -2925,7 +2925,7 @@ func TestAllocateAvailable(t *testing.T) { assert.NoError(t, err) client, err = getFTPClient(user, false, nil) if assert.NoError(t, err) { - code, response, err := client.SendCustomCommand("AVBL") + code, response, err := client.SendCommand("AVBL") assert.NoError(t, err) assert.Equal(t, ftp.StatusFile, code) assert.Equal(t, "5242880", response) @@ -2941,7 +2941,7 @@ func TestAllocateAvailable(t *testing.T) { assert.NoError(t, err) client, err = getFTPClient(user, false, nil) if assert.NoError(t, err) { - code, response, err := client.SendCustomCommand("AVBL") + code, response, err := client.SendCommand("AVBL") assert.NoError(t, err) assert.Equal(t, ftp.StatusFile, code) assert.Equal(t, "5242880", response) @@ -2958,12 +2958,12 @@ func TestAllocateAvailable(t *testing.T) { assert.NoError(t, err) client, err = getFTPClient(user, false, nil) if assert.NoError(t, err) { - code, response, err := client.SendCustomCommand("allo 10000") + code, response, err := client.SendCommand("allo 10000") assert.NoError(t, err) assert.Equal(t, ftp.StatusCommandOK, code) assert.Equal(t, "Done !", response) - code, response, err = client.SendCustomCommand("AVBL") + code, response, err = client.SendCommand("AVBL") assert.NoError(t, err) assert.Equal(t, ftp.StatusFile, code) assert.Equal(t, "100", response) @@ -2977,7 +2977,7 @@ func TestAllocateAvailable(t *testing.T) { assert.NoError(t, err) client, err = getFTPClient(user, false, nil) if assert.NoError(t, err) { - code, response, err := client.SendCustomCommand("AVBL") + code, response, err := client.SendCommand("AVBL") assert.NoError(t, err) assert.Equal(t, ftp.StatusFile, code) assert.Equal(t, "0", response) @@ -2989,7 +2989,7 @@ func TestAllocateAvailable(t *testing.T) { assert.NoError(t, err) client, err = getFTPClient(user, false, nil) if assert.NoError(t, err) { - code, response, err := client.SendCustomCommand("AVBL") + code, response, err := client.SendCommand("AVBL") assert.NoError(t, err) assert.Equal(t, ftp.StatusFile, code) assert.Equal(t, "1", response) @@ -3013,7 +3013,7 @@ func TestAvailableSFTPFs(t *testing.T) { assert.NoError(t, err) client, err := getFTPClient(sftpUser, false, nil) if assert.NoError(t, err) { - code, response, err := client.SendCustomCommand("AVBL /") + code, response, err := client.SendCommand("AVBL /") assert.NoError(t, err) assert.Equal(t, ftp.StatusFile, code) avblSize, err := strconv.ParseInt(response, 10, 64) @@ -3051,7 +3051,7 @@ func TestChtimes(t *testing.T) { assert.NoError(t, err) mtime := time.Now().Format("20060102150405") - code, response, err := client.SendCustomCommand(fmt.Sprintf("MFMT %v %v", mtime, testFileName)) + code, response, err := client.SendCommand(fmt.Sprintf("MFMT %v %v", mtime, testFileName)) assert.NoError(t, err) assert.Equal(t, ftp.StatusFile, code) assert.Equal(t, fmt.Sprintf("Modify=%v; %v", mtime, testFileName), response) @@ -3097,7 +3097,7 @@ func TestChown(t *testing.T) { assert.NoError(t, err) err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) assert.NoError(t, err) - code, response, err := client.SendCustomCommand(fmt.Sprintf("SITE CHOWN 1000:1000 %v", testFileName)) + code, response, err := client.SendCommand(fmt.Sprintf("SITE CHOWN 1000:1000 %v", testFileName)) assert.NoError(t, err) assert.Equal(t, ftp.StatusFileUnavailable, code) assert.Equal(t, "Couldn't chown: operation unsupported", response) @@ -3135,7 +3135,7 @@ func TestChmod(t *testing.T) { err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) assert.NoError(t, err) - code, response, err := client.SendCustomCommand(fmt.Sprintf("SITE CHMOD 600 %v", testFileName)) + code, response, err := client.SendCommand(fmt.Sprintf("SITE CHMOD 600 %v", testFileName)) assert.NoError(t, err) assert.Equal(t, ftp.StatusCommandOK, code) assert.Equal(t, "SITE CHMOD command successful", response) @@ -3182,7 +3182,7 @@ func TestCombineDisabled(t *testing.T) { err = checkBasicFTP(client) assert.NoError(t, err) - code, response, err := client.SendCustomCommand("COMB file file.1 file.2") + code, response, err := client.SendCommand("COMB file file.1 file.2") assert.NoError(t, err) assert.Equal(t, ftp.StatusNotImplemented, code) assert.Equal(t, "COMB support is disabled", response) @@ -3208,12 +3208,12 @@ func TestActiveModeDisabled(t *testing.T) { if assert.NoError(t, err) { err = checkBasicFTP(client) assert.NoError(t, err) - code, response, err := client.SendCustomCommand("PORT 10,2,0,2,4,31") + code, response, err := client.SendCommand("PORT 10,2,0,2,4,31") assert.NoError(t, err) assert.Equal(t, ftp.StatusNotAvailable, code) assert.Equal(t, "PORT command is disabled", response) - code, response, err = client.SendCustomCommand("EPRT |1|132.235.1.2|6275|") + code, response, err = client.SendCommand("EPRT |1|132.235.1.2|6275|") assert.NoError(t, err) assert.Equal(t, ftp.StatusNotAvailable, code) assert.Equal(t, "EPRT command is disabled", response) @@ -3224,12 +3224,12 @@ func TestActiveModeDisabled(t *testing.T) { client, err = getFTPClient(user, false, nil) if assert.NoError(t, err) { - code, response, err := client.SendCustomCommand("PORT 10,2,0,2,4,31") + code, response, err := client.SendCommand("PORT 10,2,0,2,4,31") assert.NoError(t, err) assert.Equal(t, ftp.StatusBadArguments, code) assert.Equal(t, "Your request does not meet the configured security requirements", response) - code, response, err = client.SendCustomCommand("EPRT |1|132.235.1.2|6275|") + code, response, err = client.SendCommand("EPRT |1|132.235.1.2|6275|") assert.NoError(t, err) assert.Equal(t, ftp.StatusBadArguments, code) assert.Equal(t, "Your request does not meet the configured security requirements", response) @@ -3253,7 +3253,7 @@ func TestSITEDisabled(t *testing.T) { err = checkBasicFTP(client) assert.NoError(t, err) - code, response, err := client.SendCustomCommand("SITE CHMOD 600 afile.txt") + code, response, err := client.SendCommand("SITE CHMOD 600 afile.txt") assert.NoError(t, err) assert.Equal(t, ftp.StatusBadCommand, code) assert.Equal(t, "SITE support is disabled", response) @@ -3298,12 +3298,12 @@ func TestHASH(t *testing.T) { err = f.Close() assert.NoError(t, err) - code, response, err := client.SendCustomCommand(fmt.Sprintf("XSHA256 %v", testFileName)) + code, response, err := client.SendCommand(fmt.Sprintf("XSHA256 %v", testFileName)) assert.NoError(t, err) assert.Equal(t, ftp.StatusRequestedFileActionOK, code) assert.Contains(t, response, hash) - code, response, err = client.SendCustomCommand(fmt.Sprintf("HASH %v", testFileName)) + code, response, err = client.SendCommand(fmt.Sprintf("HASH %v", testFileName)) assert.NoError(t, err) assert.Equal(t, ftp.StatusFile, code) assert.Contains(t, response, hash) @@ -3359,7 +3359,7 @@ func TestCombine(t *testing.T) { err = ftpUploadFile(testFilePath, testFileName+".2", testFileSize, client, 0) assert.NoError(t, err) - code, response, err := client.SendCustomCommand(fmt.Sprintf("COMB %v %v %v", testFileName, testFileName+".1", testFileName+".2")) + code, response, err := client.SendCommand(fmt.Sprintf("COMB %v %v %v", testFileName, testFileName+".1", testFileName+".2")) assert.NoError(t, err) if user.Username == defaultUsername { assert.Equal(t, ftp.StatusRequestedFileActionOK, code) diff --git a/internal/ftpd/handler.go b/internal/ftpd/handler.go index 1d56fd4c..036c3977 100644 --- a/internal/ftpd/handler.go +++ b/internal/ftpd/handler.go @@ -291,7 +291,7 @@ func (c *Connection) Symlink(oldname, newname string) error { } // ReadDir implements ClientDriverExtensionFilelist -func (c *Connection) ReadDir(name string) ([]os.FileInfo, error) { +func (c *Connection) ReadDir(name string) (ftpserver.DirLister, error) { c.UpdateLastActivity() if c.doWildcardListDir { @@ -302,7 +302,17 @@ func (c *Connection) ReadDir(name string) ([]os.FileInfo, error) { // - dir*/*.xml is not supported name = path.Dir(name) c.clientContext.SetListPath(name) - return c.getListDirWithWildcards(name, baseName) + lister, err := c.ListDir(name) + if err != nil { + return nil, err + } + return &patternDirLister{ + DirLister: lister, + pattern: baseName, + lastCommand: c.clientContext.GetLastCommand(), + dirName: name, + connectionPath: c.clientContext.Path(), + }, nil } return c.ListDir(name) @@ -506,31 +516,6 @@ func (c *Connection) handleFTPUploadToExistingFile(fs vfs.Fs, flags int, resolve return t, nil } -func (c *Connection) getListDirWithWildcards(dirName, pattern string) ([]os.FileInfo, error) { - files, err := c.ListDir(dirName) - if err != nil { - return files, err - } - validIdx := 0 - var relativeBase string - if c.clientContext.GetLastCommand() != "NLST" { - relativeBase = getPathRelativeTo(c.clientContext.Path(), dirName) - } - for _, fi := range files { - match, err := path.Match(pattern, fi.Name()) - if err != nil { - return files, err - } - if match { - files[validIdx] = vfs.NewFileInfo(path.Join(relativeBase, fi.Name()), fi.IsDir(), fi.Size(), - fi.ModTime(), true) - validIdx++ - } - } - - return files[:validIdx], nil -} - func (c *Connection) isListDirWithWildcards(name string) bool { if strings.ContainsAny(name, "*?[]^") { lastCommand := c.clientContext.GetLastCommand() @@ -559,3 +544,40 @@ func getPathRelativeTo(base, target string) string { base = path.Dir(path.Clean(base)) } } + +type patternDirLister struct { + vfs.DirLister + pattern string + lastCommand string + dirName string + connectionPath string +} + +func (l *patternDirLister) Next(limit int) ([]os.FileInfo, error) { + for { + files, err := l.DirLister.Next(limit) + if len(files) == 0 { + return files, err + } + validIdx := 0 + var relativeBase string + if l.lastCommand != "NLST" { + relativeBase = getPathRelativeTo(l.connectionPath, l.dirName) + } + for _, fi := range files { + match, errMatch := path.Match(l.pattern, fi.Name()) + if errMatch != nil { + return nil, errMatch + } + if match { + files[validIdx] = vfs.NewFileInfo(path.Join(relativeBase, fi.Name()), fi.IsDir(), fi.Size(), + fi.ModTime(), true) + validIdx++ + } + } + files = files[:validIdx] + if err != nil || len(files) > 0 { + return files, err + } + } +} diff --git a/internal/httpd/api_http_user.go b/internal/httpd/api_http_user.go index 3acea111..5930caa2 100644 --- a/internal/httpd/api_http_user.go +++ b/internal/httpd/api_http_user.go @@ -74,12 +74,12 @@ func readUserFolder(w http.ResponseWriter, r *http.Request) { defer common.Connections.Remove(connection.GetID()) name := connection.User.GetCleanedPath(r.URL.Query().Get("path")) - contents, err := connection.ReadDir(name) + lister, err := connection.ReadDir(name) if err != nil { - sendAPIResponse(w, r, err, "Unable to get directory contents", getMappedStatusCode(err)) + sendAPIResponse(w, r, err, "Unable to get directory lister", getMappedStatusCode(err)) return } - renderAPIDirContents(w, r, contents, false) + renderAPIDirContents(w, lister, false) } func createUserDir(w http.ResponseWriter, r *http.Request) { diff --git a/internal/httpd/api_shares.go b/internal/httpd/api_shares.go index c32838fc..d742cf67 100644 --- a/internal/httpd/api_shares.go +++ b/internal/httpd/api_shares.go @@ -213,12 +213,12 @@ func (s *httpdServer) readBrowsableShareContents(w http.ResponseWriter, r *http. } defer common.Connections.Remove(connection.GetID()) - contents, err := connection.ReadDir(name) + lister, err := connection.ReadDir(name) if err != nil { - sendAPIResponse(w, r, err, "Unable to get directory contents", getMappedStatusCode(err)) + sendAPIResponse(w, r, err, "Unable to get directory lister", getMappedStatusCode(err)) return } - renderAPIDirContents(w, r, contents, true) + renderAPIDirContents(w, lister, true) } func (s *httpdServer) downloadBrowsableSharedFile(w http.ResponseWriter, r *http.Request) { diff --git a/internal/httpd/api_utils.go b/internal/httpd/api_utils.go index b317dcf6..d3e74f94 100644 --- a/internal/httpd/api_utils.go +++ b/internal/httpd/api_utils.go @@ -17,6 +17,7 @@ package httpd import ( "bytes" "context" + "encoding/json" "errors" "fmt" "io" @@ -44,6 +45,7 @@ import ( "github.com/drakkan/sftpgo/v2/internal/plugin" "github.com/drakkan/sftpgo/v2/internal/smtp" "github.com/drakkan/sftpgo/v2/internal/util" + "github.com/drakkan/sftpgo/v2/internal/vfs" ) type pwdChange struct { @@ -280,23 +282,40 @@ func getSearchFilters(w http.ResponseWriter, r *http.Request) (int, int, string, return limit, offset, order, err } -func renderAPIDirContents(w http.ResponseWriter, r *http.Request, contents []os.FileInfo, omitNonRegularFiles bool) { - results := make([]map[string]any, 0, len(contents)) - for _, info := range contents { - if omitNonRegularFiles && !info.Mode().IsDir() && !info.Mode().IsRegular() { - continue +func renderAPIDirContents(w http.ResponseWriter, lister vfs.DirLister, omitNonRegularFiles bool) { + defer lister.Close() + + dataGetter := func(limit, _ int) ([]byte, int, error) { + contents, err := lister.Next(limit) + if errors.Is(err, io.EOF) { + err = nil } - res := make(map[string]any) - res["name"] = info.Name() - if info.Mode().IsRegular() { - res["size"] = info.Size() + if err != nil { + return nil, 0, err } - res["mode"] = info.Mode() - res["last_modified"] = info.ModTime().UTC().Format(time.RFC3339) - results = append(results, res) + results := make([]map[string]any, 0, len(contents)) + for _, info := range contents { + if omitNonRegularFiles && !info.Mode().IsDir() && !info.Mode().IsRegular() { + continue + } + res := make(map[string]any) + res["name"] = info.Name() + if info.Mode().IsRegular() { + res["size"] = info.Size() + } + res["mode"] = info.Mode() + res["last_modified"] = info.ModTime().UTC().Format(time.RFC3339) + results = append(results, res) + } + data, err := json.Marshal(results) + count := limit + if len(results) == 0 { + count = 0 + } + return data, count, err } - render.JSON(w, r, results) + streamJSONArray(w, defaultQueryLimit, dataGetter) } func streamData(w io.Writer, data []byte) { @@ -355,7 +374,7 @@ func renderCompressedFiles(w http.ResponseWriter, conn *Connection, baseDir stri for _, file := range files { fullPath := util.CleanPath(path.Join(baseDir, file)) - if err := addZipEntry(wr, conn, fullPath, baseDir); err != nil { + if err := addZipEntry(wr, conn, fullPath, baseDir, 0); err != nil { if share != nil { dataprovider.UpdateShareLastUse(share, -1) //nolint:errcheck } @@ -371,7 +390,12 @@ func renderCompressedFiles(w http.ResponseWriter, conn *Connection, baseDir stri } } -func addZipEntry(wr *zip.Writer, conn *Connection, entryPath, baseDir string) error { +func addZipEntry(wr *zip.Writer, conn *Connection, entryPath, baseDir string, recursion int) error { + if recursion >= util.MaxRecursion { + conn.Log(logger.LevelDebug, "unable to add zip entry %q, recursion too depth: %d", entryPath, recursion) + return util.ErrRecursionTooDeep + } + recursion++ info, err := conn.Stat(entryPath, 1) if err != nil { conn.Log(logger.LevelDebug, "unable to add zip entry %q, stat error: %v", entryPath, err) @@ -392,24 +416,39 @@ func addZipEntry(wr *zip.Writer, conn *Connection, entryPath, baseDir string) er conn.Log(logger.LevelError, "unable to create zip entry %q: %v", entryPath, err) return err } - contents, err := conn.ReadDir(entryPath) + lister, err := conn.ReadDir(entryPath) if err != nil { - conn.Log(logger.LevelDebug, "unable to add zip entry %q, read dir error: %v", entryPath, err) + conn.Log(logger.LevelDebug, "unable to add zip entry %q, get list dir error: %v", entryPath, err) return err } - for _, info := range contents { - fullPath := util.CleanPath(path.Join(entryPath, info.Name())) - if err := addZipEntry(wr, conn, fullPath, baseDir); err != nil { + defer lister.Close() + + for { + contents, err := lister.Next(vfs.ListerBatchSize) + finished := errors.Is(err, io.EOF) + if err != nil && !finished { return err } + for _, info := range contents { + fullPath := util.CleanPath(path.Join(entryPath, info.Name())) + if err := addZipEntry(wr, conn, fullPath, baseDir, recursion); err != nil { + return err + } + } + if finished { + return nil + } } - return nil } if !info.Mode().IsRegular() { // we only allow regular files conn.Log(logger.LevelInfo, "skipping zip entry for non regular file %q", entryPath) return nil } + return addFileToZipEntry(wr, conn, entryPath, entryName, info) +} + +func addFileToZipEntry(wr *zip.Writer, conn *Connection, entryPath, entryName string, info os.FileInfo) error { reader, err := conn.getFileReader(entryPath, 0, http.MethodGet) if err != nil { conn.Log(logger.LevelDebug, "unable to add zip entry %q, cannot open file: %v", entryPath, err) diff --git a/internal/httpd/handler.go b/internal/httpd/handler.go index 4313585b..d417cc4a 100644 --- a/internal/httpd/handler.go +++ b/internal/httpd/handler.go @@ -88,7 +88,7 @@ func (c *Connection) Stat(name string, mode int) (os.FileInfo, error) { } // ReadDir returns a list of directory entries -func (c *Connection) ReadDir(name string) ([]os.FileInfo, error) { +func (c *Connection) ReadDir(name string) (vfs.DirLister, error) { c.UpdateLastActivity() return c.ListDir(name) diff --git a/internal/httpd/httpd_test.go b/internal/httpd/httpd_test.go index 5c538101..fbb1614e 100644 --- a/internal/httpd/httpd_test.go +++ b/internal/httpd/httpd_test.go @@ -16028,7 +16028,7 @@ func TestWebGetFiles(t *testing.T) { setBearerForReq(req, webAPIToken) rr = executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) - assert.Contains(t, rr.Body.String(), "Unable to get directory contents") + assert.Contains(t, rr.Body.String(), "Unable to get directory lister") req, _ = http.NewRequest(http.MethodGet, webClientFilesPath+"?path="+testFileName, nil) setJWTCookieForReq(req, webToken) diff --git a/internal/httpd/internal_test.go b/internal/httpd/internal_test.go index 54a7b148..c0c95f6a 100644 --- a/internal/httpd/internal_test.go +++ b/internal/httpd/internal_test.go @@ -2196,7 +2196,7 @@ func TestRecoverer(t *testing.T) { } func TestStreamJSONArray(t *testing.T) { - dataGetter := func(limit, offset int) ([]byte, int, error) { + dataGetter := func(_, _ int) ([]byte, int, error) { return nil, 0, nil } rr := httptest.NewRecorder() @@ -2268,12 +2268,14 @@ func TestZipErrors(t *testing.T) { assert.Contains(t, err.Error(), "write error") } - err = addZipEntry(wr, connection, "/"+filepath.Base(testDir), "/") + err = addZipEntry(wr, connection, "/"+filepath.Base(testDir), "/", 0) if assert.Error(t, err) { assert.Contains(t, err.Error(), "write error") } + err = addZipEntry(wr, connection, "/"+filepath.Base(testDir), "/", 2000) + assert.ErrorIs(t, err, util.ErrRecursionTooDeep) - err = addZipEntry(wr, connection, "/"+filepath.Base(testDir), path.Join("/", filepath.Base(testDir), "dir")) + err = addZipEntry(wr, connection, "/"+filepath.Base(testDir), path.Join("/", filepath.Base(testDir), "dir"), 0) if assert.Error(t, err) { assert.Contains(t, err.Error(), "is outside base dir") } @@ -2282,14 +2284,14 @@ func TestZipErrors(t *testing.T) { err = os.WriteFile(testFilePath, util.GenerateRandomBytes(65535), os.ModePerm) assert.NoError(t, err) err = addZipEntry(wr, connection, path.Join("/", filepath.Base(testDir), filepath.Base(testFilePath)), - "/"+filepath.Base(testDir)) + "/"+filepath.Base(testDir), 0) if assert.Error(t, err) { assert.Contains(t, err.Error(), "write error") } connection.User.Permissions["/"] = []string{dataprovider.PermListItems} err = addZipEntry(wr, connection, path.Join("/", filepath.Base(testDir), filepath.Base(testFilePath)), - "/"+filepath.Base(testDir)) + "/"+filepath.Base(testDir), 0) assert.ErrorIs(t, err, os.ErrPermission) // creating a virtual folder to a missing path stat is ok but readdir fails @@ -2301,14 +2303,14 @@ func TestZipErrors(t *testing.T) { }) connection.User = user wr = zip.NewWriter(bytes.NewBuffer(make([]byte, 0))) - err = addZipEntry(wr, connection, user.VirtualFolders[0].VirtualPath, "/") + err = addZipEntry(wr, connection, user.VirtualFolders[0].VirtualPath, "/", 0) assert.Error(t, err) user.Filters.FilePatterns = append(user.Filters.FilePatterns, sdk.PatternsFilter{ Path: "/", DeniedPatterns: []string{"*.zip"}, }) - err = addZipEntry(wr, connection, "/"+filepath.Base(testDir), "/") + err = addZipEntry(wr, connection, "/"+filepath.Base(testDir), "/", 0) assert.ErrorIs(t, err, os.ErrPermission) err = os.RemoveAll(testDir) diff --git a/internal/httpd/webclient.go b/internal/httpd/webclient.go index e8a5f5bd..03a8a409 100644 --- a/internal/httpd/webclient.go +++ b/internal/httpd/webclient.go @@ -958,33 +958,50 @@ func (s *httpdServer) handleShareGetDirContents(w http.ResponseWriter, r *http.R } defer common.Connections.Remove(connection.GetID()) - contents, err := connection.ReadDir(name) + lister, err := connection.ReadDir(name) if err != nil { sendAPIResponse(w, r, err, getI18NErrorString(err, util.I18nErrorDirListGeneric), getMappedStatusCode(err)) return } - results := make([]map[string]any, 0, len(contents)) - for _, info := range contents { - if !info.Mode().IsDir() && !info.Mode().IsRegular() { - continue + defer lister.Close() + + dataGetter := func(limit, _ int) ([]byte, int, error) { + contents, err := lister.Next(limit) + if errors.Is(err, io.EOF) { + err = nil } - res := make(map[string]any) - if info.IsDir() { - res["type"] = "1" - res["size"] = "" - } else { - res["type"] = "2" - res["size"] = info.Size() + if err != nil { + return nil, 0, err } - res["meta"] = fmt.Sprintf("%v_%v", res["type"], info.Name()) - res["name"] = info.Name() - res["url"] = getFileObjectURL(share.GetRelativePath(name), info.Name(), - path.Join(webClientPubSharesPath, share.ShareID, "browse")) - res["last_modified"] = getFileObjectModTime(info.ModTime()) - results = append(results, res) + results := make([]map[string]any, 0, len(contents)) + for _, info := range contents { + if !info.Mode().IsDir() && !info.Mode().IsRegular() { + continue + } + res := make(map[string]any) + if info.IsDir() { + res["type"] = "1" + res["size"] = "" + } else { + res["type"] = "2" + res["size"] = info.Size() + } + res["meta"] = fmt.Sprintf("%v_%v", res["type"], info.Name()) + res["name"] = info.Name() + res["url"] = getFileObjectURL(share.GetRelativePath(name), info.Name(), + path.Join(webClientPubSharesPath, share.ShareID, "browse")) + res["last_modified"] = getFileObjectModTime(info.ModTime()) + results = append(results, res) + } + data, err := json.Marshal(results) + count := limit + if len(results) == 0 { + count = 0 + } + return data, count, err } - render.JSON(w, r, results) + streamJSONArray(w, defaultQueryLimit, dataGetter) } func (s *httpdServer) handleClientUploadToShare(w http.ResponseWriter, r *http.Request) { @@ -1146,43 +1163,59 @@ func (s *httpdServer) handleClientGetDirContents(w http.ResponseWriter, r *http. defer common.Connections.Remove(connection.GetID()) name := connection.User.GetCleanedPath(r.URL.Query().Get("path")) - contents, err := connection.ReadDir(name) + lister, err := connection.ReadDir(name) if err != nil { statusCode := getMappedStatusCode(err) sendAPIResponse(w, r, err, i18nListDirMsg(statusCode), statusCode) return } + defer lister.Close() dirTree := r.URL.Query().Get("dirtree") == "1" - results := make([]map[string]any, 0, len(contents)) - for _, info := range contents { - res := make(map[string]any) - res["url"] = getFileObjectURL(name, info.Name(), webClientFilesPath) - if info.IsDir() { - res["type"] = "1" - res["size"] = "" - res["dir_path"] = url.QueryEscape(path.Join(name, info.Name())) - } else { - if dirTree { - continue - } - res["type"] = "2" - if info.Mode()&os.ModeSymlink != 0 { + dataGetter := func(limit, _ int) ([]byte, int, error) { + contents, err := lister.Next(limit) + if errors.Is(err, io.EOF) { + err = nil + } + if err != nil { + return nil, 0, err + } + results := make([]map[string]any, 0, len(contents)) + for _, info := range contents { + res := make(map[string]any) + res["url"] = getFileObjectURL(name, info.Name(), webClientFilesPath) + if info.IsDir() { + res["type"] = "1" res["size"] = "" + res["dir_path"] = url.QueryEscape(path.Join(name, info.Name())) } else { - res["size"] = info.Size() - if info.Size() < httpdMaxEditFileSize { - res["edit_url"] = strings.Replace(res["url"].(string), webClientFilesPath, webClientEditFilePath, 1) + if dirTree { + continue + } + res["type"] = "2" + if info.Mode()&os.ModeSymlink != 0 { + res["size"] = "" + } else { + res["size"] = info.Size() + if info.Size() < httpdMaxEditFileSize { + res["edit_url"] = strings.Replace(res["url"].(string), webClientFilesPath, webClientEditFilePath, 1) + } } } + res["meta"] = fmt.Sprintf("%v_%v", res["type"], info.Name()) + res["name"] = info.Name() + res["last_modified"] = getFileObjectModTime(info.ModTime()) + results = append(results, res) } - res["meta"] = fmt.Sprintf("%v_%v", res["type"], info.Name()) - res["name"] = info.Name() - res["last_modified"] = getFileObjectModTime(info.ModTime()) - results = append(results, res) + data, err := json.Marshal(results) + count := limit + if len(results) == 0 { + count = 0 + } + return data, count, err } - render.JSON(w, r, results) + streamJSONArray(w, defaultQueryLimit, dataGetter) } func (s *httpdServer) handleClientGetFiles(w http.ResponseWriter, r *http.Request) { @@ -1917,27 +1950,45 @@ func doCheckExist(w http.ResponseWriter, r *http.Request, connection *Connection return } - contents, err := connection.ListDir(name) + lister, err := connection.ListDir(name) if err != nil { sendAPIResponse(w, r, err, "Unable to get directory contents", getMappedStatusCode(err)) return } - existing := make([]map[string]any, 0) - for _, info := range contents { - if util.Contains(filesList.Files, info.Name()) { - res := make(map[string]any) - res["name"] = info.Name() - if info.IsDir() { - res["type"] = "1" - res["size"] = "" - } else { - res["type"] = "2" - res["size"] = info.Size() - } - existing = append(existing, res) + defer lister.Close() + + dataGetter := func(limit, _ int) ([]byte, int, error) { + contents, err := lister.Next(limit) + if errors.Is(err, io.EOF) { + err = nil } + if err != nil { + return nil, 0, err + } + existing := make([]map[string]any, 0) + for _, info := range contents { + if util.Contains(filesList.Files, info.Name()) { + res := make(map[string]any) + res["name"] = info.Name() + if info.IsDir() { + res["type"] = "1" + res["size"] = "" + } else { + res["type"] = "2" + res["size"] = info.Size() + } + existing = append(existing, res) + } + } + data, err := json.Marshal(existing) + count := limit + if len(existing) == 0 { + count = 0 + } + return data, count, err } - render.JSON(w, r, existing) + + streamJSONArray(w, defaultQueryLimit, dataGetter) } func checkShareRedirectURL(next, base string) (bool, string) { diff --git a/internal/service/service.go b/internal/service/service.go index 2f7915a7..62bc17df 100644 --- a/internal/service/service.go +++ b/internal/service/service.go @@ -91,8 +91,8 @@ func (s *Service) initLogger() { // Start initializes and starts the service func (s *Service) Start(disableAWSInstallationCode bool) error { s.initLogger() - logger.Info(logSender, "", "starting SFTPGo %v, config dir: %v, config file: %v, log max size: %v log max backups: %v "+ - "log max age: %v log level: %v, log compress: %v, log utc time: %v, load data from: %q, grace time: %d secs", + logger.Info(logSender, "", "starting SFTPGo %s, config dir: %s, config file: %s, log max size: %d log max backups: %d "+ + "log max age: %d log level: %s, log compress: %t, log utc time: %t, load data from: %q, grace time: %d secs", version.GetAsString(), s.ConfigDir, s.ConfigFile, s.LogMaxSize, s.LogMaxBackups, s.LogMaxAge, s.LogLevel, s.LogCompress, s.LogUTCTime, s.LoadDataFrom, graceTime) // in portable mode we don't read configuration from file diff --git a/internal/sftpd/handler.go b/internal/sftpd/handler.go index cf00d23c..2dfdab05 100644 --- a/internal/sftpd/handler.go +++ b/internal/sftpd/handler.go @@ -216,16 +216,16 @@ func (c *Connection) Filelist(request *sftp.Request) (sftp.ListerAt, error) { switch request.Method { case "List": - files, err := c.ListDir(request.Filepath) + lister, err := c.ListDir(request.Filepath) if err != nil { return nil, err } modTime := time.Unix(0, 0) if request.Filepath != "/" || c.folderPrefix != "" { - files = util.PrependFileInfo(files, vfs.NewFileInfo("..", true, 0, modTime, false)) + lister.Add(vfs.NewFileInfo("..", true, 0, modTime, false)) } - files = util.PrependFileInfo(files, vfs.NewFileInfo(".", true, 0, modTime, false)) - return listerAt(files), nil + lister.Add(vfs.NewFileInfo(".", true, 0, modTime, false)) + return lister, nil case "Stat": if !c.User.HasPerm(dataprovider.PermListItems, path.Dir(request.Filepath)) { return nil, sftp.ErrSSHFxPermissionDenied diff --git a/internal/sftpd/internal_test.go b/internal/sftpd/internal_test.go index 71394be9..285dbf67 100644 --- a/internal/sftpd/internal_test.go +++ b/internal/sftpd/internal_test.go @@ -1294,6 +1294,17 @@ func TestSCPProtocolMessages(t *testing.T) { if assert.Error(t, err) { assert.Equal(t, protocolErrorMsg, err.Error()) } + + mockSSHChannel = MockChannel{ + Buffer: bytes.NewBuffer(respBuffer), + StdErrBuffer: bytes.NewBuffer(stdErrBuf), + ReadError: nil, + WriteError: writeErr, + } + scpCommand.connection.channel = &mockSSHChannel + + err = scpCommand.downloadDirs(nil, nil) + assert.ErrorIs(t, err, writeErr) } func TestSCPTestDownloadProtocolMessages(t *testing.T) { diff --git a/internal/sftpd/scp.go b/internal/sftpd/scp.go index cf662fc9..be4079ed 100644 --- a/internal/sftpd/scp.go +++ b/internal/sftpd/scp.go @@ -384,47 +384,65 @@ func (c *scpCommand) handleRecursiveDownload(fs vfs.Fs, dirPath, virtualPath str if err != nil { return err } - files, err := fs.ReadDir(dirPath) + // dirPath is a fs path, not a virtual path + lister, err := fs.ReadDir(dirPath) if err != nil { c.sendErrorMessage(fs, err) return err } - files = c.connection.User.FilterListDir(files, fs.GetRelativePath(dirPath)) + defer lister.Close() + + vdirs := c.connection.User.GetVirtualFoldersInfo(virtualPath) + var dirs []string - for _, file := range files { - filePath := fs.GetRelativePath(fs.Join(dirPath, file.Name())) - if file.Mode().IsRegular() || file.Mode()&os.ModeSymlink != 0 { - err = c.handleDownload(filePath) - if err != nil { - break - } - } else if file.IsDir() { - dirs = append(dirs, filePath) + for { + files, err := lister.Next(vfs.ListerBatchSize) + finished := errors.Is(err, io.EOF) + if err != nil && !finished { + c.sendErrorMessage(fs, err) + return err } - } - if err != nil { - c.sendErrorMessage(fs, err) - return err - } - for _, dir := range dirs { - err = c.handleDownload(dir) - if err != nil { + files = c.connection.User.FilterListDir(files, fs.GetRelativePath(dirPath)) + if len(vdirs) > 0 { + files = append(files, vdirs...) + vdirs = nil + } + for _, file := range files { + filePath := fs.GetRelativePath(fs.Join(dirPath, file.Name())) + if file.Mode().IsRegular() || file.Mode()&os.ModeSymlink != 0 { + err = c.handleDownload(filePath) + if err != nil { + c.sendErrorMessage(fs, err) + return err + } + } else if file.IsDir() { + dirs = append(dirs, filePath) + } + } + if finished { break } } - if err != nil { + lister.Close() + + return c.downloadDirs(fs, dirs) + } + err = errors.New("unable to send directory for non recursive copy") + c.sendErrorMessage(nil, err) + return err +} + +func (c *scpCommand) downloadDirs(fs vfs.Fs, dirs []string) error { + for _, dir := range dirs { + if err := c.handleDownload(dir); err != nil { c.sendErrorMessage(fs, err) return err } - err = c.sendProtocolMessage("E\n") - if err != nil { - return err - } - return c.readConfirmationMessage() } - err = fmt.Errorf("unable to send directory for non recursive copy") - c.sendErrorMessage(nil, err) - return err + if err := c.sendProtocolMessage("E\n"); err != nil { + return err + } + return c.readConfirmationMessage() } func (c *scpCommand) sendDownloadFileData(fs vfs.Fs, filePath string, stat os.FileInfo, transfer *transfer) error { diff --git a/internal/sftpd/server.go b/internal/sftpd/server.go index 20a574b6..e6754150 100644 --- a/internal/sftpd/server.go +++ b/internal/sftpd/server.go @@ -376,6 +376,7 @@ func (c *Configuration) Initialize(configDir string) error { c.loadModuli(configDir) sftp.SetSFTPExtensions(sftpExtensions...) //nolint:errcheck // we configure valid SFTP Extensions so we cannot get an error + sftp.MaxFilelist = vfs.ListerBatchSize if err := c.configureSecurityOptions(serverConfig); err != nil { return err diff --git a/internal/util/errors.go b/internal/util/errors.go index b8508b1d..c2bdfbce 100644 --- a/internal/util/errors.go +++ b/internal/util/errors.go @@ -15,6 +15,7 @@ package util import ( + "errors" "fmt" ) @@ -24,12 +25,16 @@ const ( "sftpgo serve -c \"\"" ) +// MaxRecursion defines the maximum number of allowed recursions +const MaxRecursion = 1000 + // errors definitions var ( - ErrValidation = NewValidationError("") - ErrNotFound = NewRecordNotFoundError("") - ErrMethodDisabled = NewMethodDisabledError("") - ErrGeneric = NewGenericError("") + ErrValidation = NewValidationError("") + ErrNotFound = NewRecordNotFoundError("") + ErrMethodDisabled = NewMethodDisabledError("") + ErrGeneric = NewGenericError("") + ErrRecursionTooDeep = errors.New("recursion too deep") ) // ValidationError raised if input data is not valid diff --git a/internal/vfs/azblobfs.go b/internal/vfs/azblobfs.go index 3a23ccce..a11aa870 100644 --- a/internal/vfs/azblobfs.go +++ b/internal/vfs/azblobfs.go @@ -56,6 +56,10 @@ const ( azFolderKey = "hdi_isfolder" ) +var ( + azureBlobDefaultPageSize = int32(5000) +) + // AzureBlobFs is a Fs implementation for Azure Blob storage. type AzureBlobFs struct { connectionID string @@ -308,7 +312,7 @@ func (fs *AzureBlobFs) Rename(source, target string) (int, int64, error) { if err != nil { return -1, -1, err } - return fs.renameInternal(source, target, fi) + return fs.renameInternal(source, target, fi, 0) } // Remove removes the named file or (empty) directory. @@ -408,76 +412,23 @@ func (*AzureBlobFs) Truncate(_ string, _ int64) error { // ReadDir reads the directory named by dirname and returns // a list of directory entries. -func (fs *AzureBlobFs) ReadDir(dirname string) ([]os.FileInfo, error) { - var result []os.FileInfo +func (fs *AzureBlobFs) ReadDir(dirname string) (DirLister, error) { // dirname must be already cleaned prefix := fs.getPrefix(dirname) - - modTimes, err := getFolderModTimes(fs.getStorageID(), dirname) - if err != nil { - return result, err - } - prefixes := make(map[string]bool) - pager := fs.containerClient.NewListBlobsHierarchyPager("/", &container.ListBlobsHierarchyOptions{ Include: container.ListBlobsInclude{ //Metadata: true, }, - Prefix: &prefix, + Prefix: &prefix, + MaxResults: &azureBlobDefaultPageSize, }) - for pager.More() { - ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) - defer cancelFn() - - resp, err := pager.NextPage(ctx) - if err != nil { - metric.AZListObjectsCompleted(err) - return result, err - } - for _, blobPrefix := range resp.ListBlobsHierarchySegmentResponse.Segment.BlobPrefixes { - name := util.GetStringFromPointer(blobPrefix.Name) - // we don't support prefixes == "/" this will be sent if a key starts with "/" - if name == "" || name == "/" { - continue - } - // sometime we have duplicate prefixes, maybe an Azurite bug - name = strings.TrimPrefix(name, prefix) - if _, ok := prefixes[strings.TrimSuffix(name, "/")]; ok { - continue - } - result = append(result, NewFileInfo(name, true, 0, time.Unix(0, 0), false)) - prefixes[strings.TrimSuffix(name, "/")] = true - } - - for _, blobItem := range resp.ListBlobsHierarchySegmentResponse.Segment.BlobItems { - name := util.GetStringFromPointer(blobItem.Name) - name = strings.TrimPrefix(name, prefix) - size := int64(0) - isDir := false - modTime := time.Unix(0, 0) - if blobItem.Properties != nil { - size = util.GetIntFromPointer(blobItem.Properties.ContentLength) - modTime = util.GetTimeFromPointer(blobItem.Properties.LastModified) - contentType := util.GetStringFromPointer(blobItem.Properties.ContentType) - isDir = checkDirectoryMarkers(contentType, blobItem.Metadata) - if isDir { - // check if the dir is already included, it will be sent as blob prefix if it contains at least one item - if _, ok := prefixes[name]; ok { - continue - } - prefixes[name] = true - } - } - if t, ok := modTimes[name]; ok { - modTime = util.GetTimeFromMsecSinceEpoch(t) - } - result = append(result, NewFileInfo(name, isDir, size, modTime, false)) - } - } - metric.AZListObjectsCompleted(nil) - - return result, nil + return &azureBlobDirLister{ + paginator: pager, + timeout: fs.ctxTimeout, + prefix: prefix, + prefixes: make(map[string]bool), + }, nil } // IsUploadResumeSupported returns true if resuming uploads is supported. @@ -569,7 +520,8 @@ func (fs *AzureBlobFs) getFileNamesInPrefix(fsPrefix string) (map[string]bool, e Include: container.ListBlobsInclude{ //Metadata: true, }, - Prefix: &prefix, + Prefix: &prefix, + MaxResults: &azureBlobDefaultPageSize, }) for pager.More() { @@ -615,7 +567,8 @@ func (fs *AzureBlobFs) GetDirSize(dirname string) (int, int64, error) { Include: container.ListBlobsInclude{ Metadata: true, }, - Prefix: &prefix, + Prefix: &prefix, + MaxResults: &azureBlobDefaultPageSize, }) for pager.More() { @@ -684,7 +637,8 @@ func (fs *AzureBlobFs) Walk(root string, walkFn filepath.WalkFunc) error { Include: container.ListBlobsInclude{ Metadata: true, }, - Prefix: &prefix, + Prefix: &prefix, + MaxResults: &azureBlobDefaultPageSize, }) for pager.More() { @@ -863,7 +817,7 @@ func (fs *AzureBlobFs) copyFileInternal(source, target string) error { return nil } -func (fs *AzureBlobFs) renameInternal(source, target string, fi os.FileInfo) (int, int64, error) { +func (fs *AzureBlobFs) renameInternal(source, target string, fi os.FileInfo, recursion int) (int, int64, error) { var numFiles int var filesSize int64 @@ -881,24 +835,12 @@ func (fs *AzureBlobFs) renameInternal(source, target string, fi os.FileInfo) (in return numFiles, filesSize, err } if renameMode == 1 { - entries, err := fs.ReadDir(source) + files, size, err := doRecursiveRename(fs, source, target, fs.renameInternal, recursion) + numFiles += files + filesSize += size if err != nil { return numFiles, filesSize, err } - for _, info := range entries { - sourceEntry := fs.Join(source, info.Name()) - targetEntry := fs.Join(target, info.Name()) - files, size, err := fs.renameInternal(sourceEntry, targetEntry, info) - if err != nil { - if fs.IsNotExist(err) { - fsLog(fs, logger.LevelInfo, "skipping rename for %q: %v", sourceEntry, err) - continue - } - return numFiles, filesSize, err - } - numFiles += files - filesSize += size - } } } else { if err := fs.copyFileInternal(source, target); err != nil { @@ -1312,3 +1254,80 @@ func (b *bufferAllocator) free() { b.available = nil b.finalized = true } + +type azureBlobDirLister struct { + baseDirLister + paginator *runtime.Pager[container.ListBlobsHierarchyResponse] + timeout time.Duration + prefix string + prefixes map[string]bool + metricUpdated bool +} + +func (l *azureBlobDirLister) Next(limit int) ([]os.FileInfo, error) { + if limit <= 0 { + return nil, errInvalidDirListerLimit + } + if len(l.cache) >= limit { + return l.returnFromCache(limit), nil + } + if !l.paginator.More() { + if !l.metricUpdated { + l.metricUpdated = true + metric.AZListObjectsCompleted(nil) + } + return l.returnFromCache(limit), io.EOF + } + ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(l.timeout)) + defer cancelFn() + + page, err := l.paginator.NextPage(ctx) + if err != nil { + metric.AZListObjectsCompleted(err) + return l.cache, err + } + + for _, blobPrefix := range page.ListBlobsHierarchySegmentResponse.Segment.BlobPrefixes { + name := util.GetStringFromPointer(blobPrefix.Name) + // we don't support prefixes == "/" this will be sent if a key starts with "/" + if name == "" || name == "/" { + continue + } + // sometime we have duplicate prefixes, maybe an Azurite bug + name = strings.TrimPrefix(name, l.prefix) + if _, ok := l.prefixes[strings.TrimSuffix(name, "/")]; ok { + continue + } + l.cache = append(l.cache, NewFileInfo(name, true, 0, time.Unix(0, 0), false)) + l.prefixes[strings.TrimSuffix(name, "/")] = true + } + + for _, blobItem := range page.ListBlobsHierarchySegmentResponse.Segment.BlobItems { + name := util.GetStringFromPointer(blobItem.Name) + name = strings.TrimPrefix(name, l.prefix) + size := int64(0) + isDir := false + modTime := time.Unix(0, 0) + if blobItem.Properties != nil { + size = util.GetIntFromPointer(blobItem.Properties.ContentLength) + modTime = util.GetTimeFromPointer(blobItem.Properties.LastModified) + contentType := util.GetStringFromPointer(blobItem.Properties.ContentType) + isDir = checkDirectoryMarkers(contentType, blobItem.Metadata) + if isDir { + // check if the dir is already included, it will be sent as blob prefix if it contains at least one item + if _, ok := l.prefixes[name]; ok { + continue + } + l.prefixes[name] = true + } + } + l.cache = append(l.cache, NewFileInfo(name, isDir, size, modTime, false)) + } + + return l.returnFromCache(limit), nil +} + +func (l *azureBlobDirLister) Close() error { + clear(l.prefixes) + return l.baseDirLister.Close() +} diff --git a/internal/vfs/cryptfs.go b/internal/vfs/cryptfs.go index d8e82cc5..3f20fb37 100644 --- a/internal/vfs/cryptfs.go +++ b/internal/vfs/cryptfs.go @@ -221,21 +221,16 @@ func (*CryptFs) Truncate(_ string, _ int64) error { // ReadDir reads the directory named by dirname and returns // a list of directory entries. -func (fs *CryptFs) ReadDir(dirname string) ([]os.FileInfo, error) { +func (fs *CryptFs) ReadDir(dirname string) (DirLister, error) { f, err := os.Open(dirname) if err != nil { + if isInvalidNameError(err) { + err = os.ErrNotExist + } return nil, err } - list, err := f.Readdir(-1) - f.Close() - if err != nil { - return nil, err - } - result := make([]os.FileInfo, 0, len(list)) - for _, info := range list { - result = append(result, fs.ConvertFileInfo(info)) - } - return result, nil + + return &cryptFsDirLister{f}, nil } // IsUploadResumeSupported returns false sio does not support random access writes @@ -289,20 +284,7 @@ func (fs *CryptFs) getSIOConfig(key [32]byte) sio.Config { // ConvertFileInfo returns a FileInfo with the decrypted size func (fs *CryptFs) ConvertFileInfo(info os.FileInfo) os.FileInfo { - if !info.Mode().IsRegular() { - return info - } - size := info.Size() - if size >= headerV10Size { - size -= headerV10Size - decryptedSize, err := sio.DecryptedSize(uint64(size)) - if err == nil { - size = int64(decryptedSize) - } - } else { - size = 0 - } - return NewFileInfo(info.Name(), info.IsDir(), size, info.ModTime(), false) + return convertCryptFsInfo(info) } func (fs *CryptFs) getFileAndEncryptionKey(name string) (*os.File, [32]byte, error) { @@ -366,6 +348,23 @@ func isZeroBytesDownload(f *os.File, offset int64) (bool, error) { return false, nil } +func convertCryptFsInfo(info os.FileInfo) os.FileInfo { + if !info.Mode().IsRegular() { + return info + } + size := info.Size() + if size >= headerV10Size { + size -= headerV10Size + decryptedSize, err := sio.DecryptedSize(uint64(size)) + if err == nil { + size = int64(decryptedSize) + } + } else { + size = 0 + } + return NewFileInfo(info.Name(), info.IsDir(), size, info.ModTime(), false) +} + type encryptedFileHeader struct { version byte nonce []byte @@ -400,3 +399,22 @@ type cryptedFileWrapper struct { func (w *cryptedFileWrapper) ReadAt(p []byte, offset int64) (n int, err error) { return w.File.ReadAt(p, offset+headerV10Size) } + +type cryptFsDirLister struct { + f *os.File +} + +func (l *cryptFsDirLister) Next(limit int) ([]os.FileInfo, error) { + if limit <= 0 { + return nil, errInvalidDirListerLimit + } + files, err := l.f.Readdir(limit) + for idx := range files { + files[idx] = convertCryptFsInfo(files[idx]) + } + return files, err +} + +func (l *cryptFsDirLister) Close() error { + return l.f.Close() +} diff --git a/internal/vfs/gcsfs.go b/internal/vfs/gcsfs.go index c1803875..0e6c3286 100644 --- a/internal/vfs/gcsfs.go +++ b/internal/vfs/gcsfs.go @@ -266,7 +266,7 @@ func (fs *GCSFs) Rename(source, target string) (int, int64, error) { if err != nil { return -1, -1, err } - return fs.renameInternal(source, target, fi) + return fs.renameInternal(source, target, fi, 0) } // Remove removes the named file or (empty) directory. @@ -369,80 +369,23 @@ func (*GCSFs) Truncate(_ string, _ int64) error { // ReadDir reads the directory named by dirname and returns // a list of directory entries. -func (fs *GCSFs) ReadDir(dirname string) ([]os.FileInfo, error) { - var result []os.FileInfo +func (fs *GCSFs) ReadDir(dirname string) (DirLister, error) { // dirname must be already cleaned prefix := fs.getPrefix(dirname) - query := &storage.Query{Prefix: prefix, Delimiter: "/"} err := query.SetAttrSelection(gcsDefaultFieldsSelection) if err != nil { return nil, err } - - modTimes, err := getFolderModTimes(fs.getStorageID(), dirname) - if err != nil { - return result, err - } - - prefixes := make(map[string]bool) - ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxLongTimeout)) - defer cancelFn() - bkt := fs.svc.Bucket(fs.config.Bucket) - it := bkt.Objects(ctx, query) - pager := iterator.NewPager(it, defaultGCSPageSize, "") - for { - var objects []*storage.ObjectAttrs - pageToken, err := pager.NextPage(&objects) - if err != nil { - metric.GCSListObjectsCompleted(err) - return result, err - } - - for _, attrs := range objects { - if attrs.Prefix != "" { - name, _ := fs.resolve(attrs.Prefix, prefix, attrs.ContentType) - if name == "" { - continue - } - if _, ok := prefixes[name]; ok { - continue - } - result = append(result, NewFileInfo(name, true, 0, time.Unix(0, 0), false)) - prefixes[name] = true - } else { - name, isDir := fs.resolve(attrs.Name, prefix, attrs.ContentType) - if name == "" { - continue - } - if !attrs.Deleted.IsZero() { - continue - } - if isDir { - // check if the dir is already included, it will be sent as blob prefix if it contains at least one item - if _, ok := prefixes[name]; ok { - continue - } - prefixes[name] = true - } - modTime := attrs.Updated - if t, ok := modTimes[name]; ok { - modTime = util.GetTimeFromMsecSinceEpoch(t) - } - result = append(result, NewFileInfo(name, isDir, attrs.Size, modTime, false)) - } - } - - objects = nil - if pageToken == "" { - break - } - } - - metric.GCSListObjectsCompleted(nil) - return result, nil + return &gcsDirLister{ + bucket: bkt, + query: query, + timeout: fs.ctxTimeout, + prefix: prefix, + prefixes: make(map[string]bool), + }, nil } // IsUploadResumeSupported returns true if resuming uploads is supported. @@ -853,7 +796,7 @@ func (fs *GCSFs) copyFileInternal(source, target string) error { return err } -func (fs *GCSFs) renameInternal(source, target string, fi os.FileInfo) (int, int64, error) { +func (fs *GCSFs) renameInternal(source, target string, fi os.FileInfo, recursion int) (int, int64, error) { var numFiles int var filesSize int64 @@ -871,24 +814,12 @@ func (fs *GCSFs) renameInternal(source, target string, fi os.FileInfo) (int, int return numFiles, filesSize, err } if renameMode == 1 { - entries, err := fs.ReadDir(source) + files, size, err := doRecursiveRename(fs, source, target, fs.renameInternal, recursion) + numFiles += files + filesSize += size if err != nil { return numFiles, filesSize, err } - for _, info := range entries { - sourceEntry := fs.Join(source, info.Name()) - targetEntry := fs.Join(target, info.Name()) - files, size, err := fs.renameInternal(sourceEntry, targetEntry, info) - if err != nil { - if fs.IsNotExist(err) { - fsLog(fs, logger.LevelInfo, "skipping rename for %q: %v", sourceEntry, err) - continue - } - return numFiles, filesSize, err - } - numFiles += files - filesSize += size - } } } else { if err := fs.copyFileInternal(source, target); err != nil { @@ -1010,3 +941,97 @@ func (*GCSFs) getTempObject(name string) string { func (fs *GCSFs) getStorageID() string { return fmt.Sprintf("gs://%v", fs.config.Bucket) } + +type gcsDirLister struct { + baseDirLister + bucket *storage.BucketHandle + query *storage.Query + timeout time.Duration + nextPageToken string + noMorePages bool + prefix string + prefixes map[string]bool + metricUpdated bool +} + +func (l *gcsDirLister) resolve(name, contentType string) (string, bool) { + result := strings.TrimPrefix(name, l.prefix) + isDir := strings.HasSuffix(result, "/") + if isDir { + result = strings.TrimSuffix(result, "/") + } + if contentType == dirMimeType { + isDir = true + } + return result, isDir +} + +func (l *gcsDirLister) Next(limit int) ([]os.FileInfo, error) { + if limit <= 0 { + return nil, errInvalidDirListerLimit + } + if len(l.cache) >= limit { + return l.returnFromCache(limit), nil + } + + if l.noMorePages { + if !l.metricUpdated { + l.metricUpdated = true + metric.GCSListObjectsCompleted(nil) + } + return l.returnFromCache(limit), io.EOF + } + + ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(l.timeout)) + defer cancelFn() + + it := l.bucket.Objects(ctx, l.query) + paginator := iterator.NewPager(it, defaultGCSPageSize, l.nextPageToken) + var objects []*storage.ObjectAttrs + + pageToken, err := paginator.NextPage(&objects) + if err != nil { + metric.GCSListObjectsCompleted(err) + return l.cache, err + } + + for _, attrs := range objects { + if attrs.Prefix != "" { + name, _ := l.resolve(attrs.Prefix, attrs.ContentType) + if name == "" { + continue + } + if _, ok := l.prefixes[name]; ok { + continue + } + l.cache = append(l.cache, NewFileInfo(name, true, 0, time.Unix(0, 0), false)) + l.prefixes[name] = true + } else { + name, isDir := l.resolve(attrs.Name, attrs.ContentType) + if name == "" { + continue + } + if !attrs.Deleted.IsZero() { + continue + } + if isDir { + // check if the dir is already included, it will be sent as blob prefix if it contains at least one item + if _, ok := l.prefixes[name]; ok { + continue + } + l.prefixes[name] = true + } + l.cache = append(l.cache, NewFileInfo(name, isDir, attrs.Size, attrs.Updated, false)) + } + } + + l.nextPageToken = pageToken + l.noMorePages = (l.nextPageToken == "") + + return l.returnFromCache(limit), nil +} + +func (l *gcsDirLister) Close() error { + clear(l.prefixes) + return l.baseDirLister.Close() +} diff --git a/internal/vfs/httpfs.go b/internal/vfs/httpfs.go index 67d2c915..331f72dc 100644 --- a/internal/vfs/httpfs.go +++ b/internal/vfs/httpfs.go @@ -488,7 +488,7 @@ func (fs *HTTPFs) Truncate(name string, size int64) error { // ReadDir reads the directory named by dirname and returns // a list of directory entries. -func (fs *HTTPFs) ReadDir(dirname string) ([]os.FileInfo, error) { +func (fs *HTTPFs) ReadDir(dirname string) (DirLister, error) { ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) defer cancelFn() @@ -511,7 +511,7 @@ func (fs *HTTPFs) ReadDir(dirname string) ([]os.FileInfo, error) { for _, stat := range response { result = append(result, stat.getFileInfo()) } - return result, nil + return &baseDirLister{result}, nil } // IsUploadResumeSupported returns true if resuming uploads is supported. @@ -731,19 +731,33 @@ func (fs *HTTPFs) walk(filePath string, info fs.FileInfo, walkFn filepath.WalkFu if !info.IsDir() { return walkFn(filePath, info, nil) } - files, err := fs.ReadDir(filePath) + lister, err := fs.ReadDir(filePath) err1 := walkFn(filePath, info, err) if err != nil || err1 != nil { + if err == nil { + lister.Close() + } return err1 } - for _, fi := range files { - objName := path.Join(filePath, fi.Name()) - err = fs.walk(objName, fi, walkFn) - if err != nil { + defer lister.Close() + + for { + files, err := lister.Next(ListerBatchSize) + finished := errors.Is(err, io.EOF) + if err != nil && !finished { return err } + for _, fi := range files { + objName := path.Join(filePath, fi.Name()) + err = fs.walk(objName, fi, walkFn) + if err != nil { + return err + } + } + if finished { + return nil + } } - return nil } func getErrorFromResponseCode(code int) error { diff --git a/internal/vfs/osfs.go b/internal/vfs/osfs.go index 7f8de472..affc44f4 100644 --- a/internal/vfs/osfs.go +++ b/internal/vfs/osfs.go @@ -190,7 +190,7 @@ func (fs *OsFs) Rename(source, target string) (int, int64, error) { } err = fscopy.Copy(source, target, fscopy.Options{ - OnSymlink: func(src string) fscopy.SymlinkAction { + OnSymlink: func(_ string) fscopy.SymlinkAction { return fscopy.Skip }, CopyBufferSize: readBufferSize, @@ -258,7 +258,7 @@ func (*OsFs) Truncate(name string, size int64) error { // ReadDir reads the directory named by dirname and returns // a list of directory entries. -func (*OsFs) ReadDir(dirname string) ([]os.FileInfo, error) { +func (*OsFs) ReadDir(dirname string) (DirLister, error) { f, err := os.Open(dirname) if err != nil { if isInvalidNameError(err) { @@ -266,12 +266,7 @@ func (*OsFs) ReadDir(dirname string) ([]os.FileInfo, error) { } return nil, err } - list, err := f.Readdir(-1) - f.Close() - if err != nil { - return nil, err - } - return list, nil + return &osFsDirLister{f}, nil } // IsUploadResumeSupported returns true if resuming uploads is supported @@ -599,3 +594,18 @@ func (fs *OsFs) useWriteBuffering(flag int) bool { } return true } + +type osFsDirLister struct { + f *os.File +} + +func (l *osFsDirLister) Next(limit int) ([]os.FileInfo, error) { + if limit <= 0 { + return nil, errInvalidDirListerLimit + } + return l.f.Readdir(limit) +} + +func (l *osFsDirLister) Close() error { + return l.f.Close() +} diff --git a/internal/vfs/s3fs.go b/internal/vfs/s3fs.go index 9b2bdcf9..6a6cd214 100644 --- a/internal/vfs/s3fs.go +++ b/internal/vfs/s3fs.go @@ -22,6 +22,7 @@ import ( "crypto/tls" "errors" "fmt" + "io" "mime" "net" "net/http" @@ -61,7 +62,8 @@ const ( ) var ( - s3DirMimeTypes = []string{s3DirMimeType, "httpd/unix-directory"} + s3DirMimeTypes = []string{s3DirMimeType, "httpd/unix-directory"} + s3DefaultPageSize = int32(5000) ) // S3Fs is a Fs implementation for AWS S3 compatible object storages @@ -337,7 +339,7 @@ func (fs *S3Fs) Rename(source, target string) (int, int64, error) { if err != nil { return -1, -1, err } - return fs.renameInternal(source, target, fi) + return fs.renameInternal(source, target, fi, 0) } // Remove removes the named file or (empty) directory. @@ -426,68 +428,22 @@ func (*S3Fs) Truncate(_ string, _ int64) error { // ReadDir reads the directory named by dirname and returns // a list of directory entries. -func (fs *S3Fs) ReadDir(dirname string) ([]os.FileInfo, error) { - var result []os.FileInfo +func (fs *S3Fs) ReadDir(dirname string) (DirLister, error) { // dirname must be already cleaned prefix := fs.getPrefix(dirname) - - modTimes, err := getFolderModTimes(fs.getStorageID(), dirname) - if err != nil { - return result, err - } - prefixes := make(map[string]bool) - paginator := s3.NewListObjectsV2Paginator(fs.svc, &s3.ListObjectsV2Input{ Bucket: aws.String(fs.config.Bucket), Prefix: aws.String(prefix), Delimiter: aws.String("/"), + MaxKeys: &s3DefaultPageSize, }) - for paginator.HasMorePages() { - ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) - defer cancelFn() - - page, err := paginator.NextPage(ctx) - if err != nil { - metric.S3ListObjectsCompleted(err) - return result, err - } - for _, p := range page.CommonPrefixes { - // prefixes have a trailing slash - name, _ := fs.resolve(p.Prefix, prefix) - if name == "" { - continue - } - if _, ok := prefixes[name]; ok { - continue - } - result = append(result, NewFileInfo(name, true, 0, time.Unix(0, 0), false)) - prefixes[name] = true - } - for _, fileObject := range page.Contents { - objectModTime := util.GetTimeFromPointer(fileObject.LastModified) - objectSize := util.GetIntFromPointer(fileObject.Size) - name, isDir := fs.resolve(fileObject.Key, prefix) - if name == "" || name == "/" { - continue - } - if isDir { - if _, ok := prefixes[name]; ok { - continue - } - prefixes[name] = true - } - if t, ok := modTimes[name]; ok { - objectModTime = util.GetTimeFromMsecSinceEpoch(t) - } - - result = append(result, NewFileInfo(name, (isDir && objectSize == 0), objectSize, - objectModTime, false)) - } - } - - metric.S3ListObjectsCompleted(nil) - return result, nil + return &s3DirLister{ + paginator: paginator, + timeout: fs.ctxTimeout, + prefix: prefix, + prefixes: make(map[string]bool), + }, nil } // IsUploadResumeSupported returns true if resuming uploads is supported. @@ -574,6 +530,7 @@ func (fs *S3Fs) getFileNamesInPrefix(fsPrefix string) (map[string]bool, error) { Bucket: aws.String(fs.config.Bucket), Prefix: aws.String(prefix), Delimiter: aws.String("/"), + MaxKeys: &s3DefaultPageSize, }) for paginator.HasMorePages() { @@ -614,8 +571,9 @@ func (fs *S3Fs) GetDirSize(dirname string) (int, int64, error) { size := int64(0) paginator := s3.NewListObjectsV2Paginator(fs.svc, &s3.ListObjectsV2Input{ - Bucket: aws.String(fs.config.Bucket), - Prefix: aws.String(prefix), + Bucket: aws.String(fs.config.Bucket), + Prefix: aws.String(prefix), + MaxKeys: &s3DefaultPageSize, }) for paginator.HasMorePages() { @@ -679,8 +637,9 @@ func (fs *S3Fs) Walk(root string, walkFn filepath.WalkFunc) error { prefix := fs.getPrefix(root) paginator := s3.NewListObjectsV2Paginator(fs.svc, &s3.ListObjectsV2Input{ - Bucket: aws.String(fs.config.Bucket), - Prefix: aws.String(prefix), + Bucket: aws.String(fs.config.Bucket), + Prefix: aws.String(prefix), + MaxKeys: &s3DefaultPageSize, }) for paginator.HasMorePages() { @@ -797,7 +756,7 @@ func (fs *S3Fs) copyFileInternal(source, target string, fileSize int64) error { return err } -func (fs *S3Fs) renameInternal(source, target string, fi os.FileInfo) (int, int64, error) { +func (fs *S3Fs) renameInternal(source, target string, fi os.FileInfo, recursion int) (int, int64, error) { var numFiles int var filesSize int64 @@ -815,24 +774,12 @@ func (fs *S3Fs) renameInternal(source, target string, fi os.FileInfo) (int, int6 return numFiles, filesSize, err } if renameMode == 1 { - entries, err := fs.ReadDir(source) + files, size, err := doRecursiveRename(fs, source, target, fs.renameInternal, recursion) + numFiles += files + filesSize += size if err != nil { return numFiles, filesSize, err } - for _, info := range entries { - sourceEntry := fs.Join(source, info.Name()) - targetEntry := fs.Join(target, info.Name()) - files, size, err := fs.renameInternal(sourceEntry, targetEntry, info) - if err != nil { - if fs.IsNotExist(err) { - fsLog(fs, logger.LevelInfo, "skipping rename for %q: %v", sourceEntry, err) - continue - } - return numFiles, filesSize, err - } - numFiles += files - filesSize += size - } } } else { if err := fs.copyFileInternal(source, target, fi.Size()); err != nil { @@ -1114,6 +1061,81 @@ func (fs *S3Fs) getStorageID() string { return fmt.Sprintf("s3://%v", fs.config.Bucket) } +type s3DirLister struct { + baseDirLister + paginator *s3.ListObjectsV2Paginator + timeout time.Duration + prefix string + prefixes map[string]bool + metricUpdated bool +} + +func (l *s3DirLister) resolve(name *string) (string, bool) { + result := strings.TrimPrefix(util.GetStringFromPointer(name), l.prefix) + isDir := strings.HasSuffix(result, "/") + if isDir { + result = strings.TrimSuffix(result, "/") + } + return result, isDir +} + +func (l *s3DirLister) Next(limit int) ([]os.FileInfo, error) { + if limit <= 0 { + return nil, errInvalidDirListerLimit + } + if len(l.cache) >= limit { + return l.returnFromCache(limit), nil + } + if !l.paginator.HasMorePages() { + if !l.metricUpdated { + l.metricUpdated = true + metric.S3ListObjectsCompleted(nil) + } + return l.returnFromCache(limit), io.EOF + } + ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(l.timeout)) + defer cancelFn() + + page, err := l.paginator.NextPage(ctx) + if err != nil { + metric.S3ListObjectsCompleted(err) + return l.cache, err + } + for _, p := range page.CommonPrefixes { + // prefixes have a trailing slash + name, _ := l.resolve(p.Prefix) + if name == "" { + continue + } + if _, ok := l.prefixes[name]; ok { + continue + } + l.cache = append(l.cache, NewFileInfo(name, true, 0, time.Unix(0, 0), false)) + l.prefixes[name] = true + } + for _, fileObject := range page.Contents { + objectModTime := util.GetTimeFromPointer(fileObject.LastModified) + objectSize := util.GetIntFromPointer(fileObject.Size) + name, isDir := l.resolve(fileObject.Key) + if name == "" || name == "/" { + continue + } + if isDir { + if _, ok := l.prefixes[name]; ok { + continue + } + l.prefixes[name] = true + } + + l.cache = append(l.cache, NewFileInfo(name, (isDir && objectSize == 0), objectSize, objectModTime, false)) + } + return l.returnFromCache(limit), nil +} + +func (l *s3DirLister) Close() error { + return l.baseDirLister.Close() +} + func getAWSHTTPClient(timeout int, idleConnectionTimeout time.Duration, skipTLSVerify bool) *awshttp.BuildableClient { c := awshttp.NewBuildableClient(). WithDialerOptions(func(d *net.Dialer) { diff --git a/internal/vfs/sftpfs.go b/internal/vfs/sftpfs.go index 929ef7d6..fa606ee5 100644 --- a/internal/vfs/sftpfs.go +++ b/internal/vfs/sftpfs.go @@ -541,12 +541,16 @@ func (fs *SFTPFs) Truncate(name string, size int64) error { // ReadDir reads the directory named by dirname and returns // a list of directory entries. -func (fs *SFTPFs) ReadDir(dirname string) ([]os.FileInfo, error) { +func (fs *SFTPFs) ReadDir(dirname string) (DirLister, error) { client, err := fs.conn.getClient() if err != nil { return nil, err } - return client.ReadDir(dirname) + files, err := client.ReadDir(dirname) + if err != nil { + return nil, err + } + return &baseDirLister{files}, nil } // IsUploadResumeSupported returns true if resuming uploads is supported. diff --git a/internal/vfs/vfs.go b/internal/vfs/vfs.go index 7832ab44..3af5b44b 100644 --- a/internal/vfs/vfs.go +++ b/internal/vfs/vfs.go @@ -45,6 +45,8 @@ const ( gcsfsName = "GCSFs" azBlobFsName = "AzureBlobFs" preResumeTimeout = 90 * time.Second + // ListerBatchSize defines the default limit for DirLister implementations + ListerBatchSize = 1000 ) // Additional checks for files @@ -58,14 +60,15 @@ var ( // ErrStorageSizeUnavailable is returned if the storage backend does not support getting the size ErrStorageSizeUnavailable = errors.New("unable to get available size for this storage backend") // ErrVfsUnsupported defines the error for an unsupported VFS operation - ErrVfsUnsupported = errors.New("not supported") - tempPath string - sftpFingerprints []string - allowSelfConnections int - renameMode int - readMetadata int - resumeMaxSize int64 - uploadMode int + ErrVfsUnsupported = errors.New("not supported") + errInvalidDirListerLimit = errors.New("dir lister: invalid limit, must be > 0") + tempPath string + sftpFingerprints []string + allowSelfConnections int + renameMode int + readMetadata int + resumeMaxSize int64 + uploadMode int ) // SetAllowSelfConnections sets the desired behaviour for self connections @@ -125,7 +128,7 @@ type Fs interface { Chmod(name string, mode os.FileMode) error Chtimes(name string, atime, mtime time.Time, isUploading bool) error Truncate(name string, size int64) error - ReadDir(dirname string) ([]os.FileInfo, error) + ReadDir(dirname string) (DirLister, error) Readlink(name string) (string, error) IsUploadResumeSupported() bool IsConditionalUploadResumeSupported(size int64) bool @@ -199,11 +202,47 @@ type PipeReader interface { Metadata() map[string]string } +// DirLister defines an interface for a directory lister +type DirLister interface { + Next(limit int) ([]os.FileInfo, error) + Close() error +} + // Metadater defines an interface to implement to return metadata for a file type Metadater interface { Metadata() map[string]string } +type baseDirLister struct { + cache []os.FileInfo +} + +func (l *baseDirLister) Next(limit int) ([]os.FileInfo, error) { + if limit <= 0 { + return nil, errInvalidDirListerLimit + } + if len(l.cache) >= limit { + return l.returnFromCache(limit), nil + } + return l.returnFromCache(limit), io.EOF +} + +func (l *baseDirLister) returnFromCache(limit int) []os.FileInfo { + if len(l.cache) >= limit { + result := l.cache[:limit] + l.cache = l.cache[limit:] + return result + } + result := l.cache + l.cache = nil + return result +} + +func (l *baseDirLister) Close() error { + l.cache = nil + return nil +} + // QuotaCheckResult defines the result for a quota check type QuotaCheckResult struct { HasSpace bool @@ -1069,18 +1108,6 @@ func updateFileInfoModTime(storageID, objectPath string, info *FileInfo) (*FileI return info, nil } -func getFolderModTimes(storageID, dirName string) (map[string]int64, error) { - var err error - modTimes := make(map[string]int64) - if plugin.Handler.HasMetadater() { - modTimes, err = plugin.Handler.GetModificationTimes(storageID, ensureAbsPath(dirName)) - if err != nil && !errors.Is(err, metadata.ErrNoSuchObject) { - return modTimes, err - } - } - return modTimes, nil -} - func ensureAbsPath(name string) string { if path.IsAbs(name) { return name @@ -1205,6 +1232,50 @@ func getLocalTempDir() string { return filepath.Clean(os.TempDir()) } +func doRecursiveRename(fs Fs, source, target string, + renameFn func(string, string, os.FileInfo, int) (int, int64, error), + recursion int, +) (int, int64, error) { + var numFiles int + var filesSize int64 + + if recursion > util.MaxRecursion { + return numFiles, filesSize, util.ErrRecursionTooDeep + } + recursion++ + + lister, err := fs.ReadDir(source) + if err != nil { + return numFiles, filesSize, err + } + defer lister.Close() + + for { + entries, err := lister.Next(ListerBatchSize) + finished := errors.Is(err, io.EOF) + if err != nil && !finished { + return numFiles, filesSize, err + } + for _, info := range entries { + sourceEntry := fs.Join(source, info.Name()) + targetEntry := fs.Join(target, info.Name()) + files, size, err := renameFn(sourceEntry, targetEntry, info, recursion) + if err != nil { + if fs.IsNotExist(err) { + fsLog(fs, logger.LevelInfo, "skipping rename for %q: %v", sourceEntry, err) + continue + } + return numFiles, filesSize, err + } + numFiles += files + filesSize += size + } + if finished { + return numFiles, filesSize, nil + } + } +} + func fsLog(fs Fs, level logger.LogLevel, format string, v ...any) { logger.Log(level, fs.Name(), fs.ConnectionID(), format, v...) } diff --git a/internal/webdavd/file.go b/internal/webdavd/file.go index 134aba96..a8529299 100644 --- a/internal/webdavd/file.go +++ b/internal/webdavd/file.go @@ -108,22 +108,24 @@ func (fi *webDavFileInfo) ContentType(_ context.Context) (string, error) { // Readdir reads directory entries from the handle func (f *webDavFile) Readdir(_ int) ([]os.FileInfo, error) { + return nil, webdav.ErrNotImplemented +} + +// ReadDir implements the FileDirLister interface +func (f *webDavFile) ReadDir() (webdav.DirLister, error) { if !f.Connection.User.HasPerm(dataprovider.PermListItems, f.GetVirtualPath()) { return nil, f.Connection.GetPermissionDeniedError() } - entries, err := f.Connection.ListDir(f.GetVirtualPath()) + lister, err := f.Connection.ListDir(f.GetVirtualPath()) if err != nil { return nil, err } - for idx, info := range entries { - entries[idx] = &webDavFileInfo{ - FileInfo: info, - Fs: f.Fs, - virtualPath: path.Join(f.GetVirtualPath(), info.Name()), - fsPath: f.Fs.Join(f.GetFsPath(), info.Name()), - } - } - return entries, nil + return &webDavDirLister{ + DirLister: lister, + fs: f.Fs, + virtualDirPath: f.GetVirtualPath(), + fsDirPath: f.GetFsPath(), + }, nil } // Stat the handle @@ -474,3 +476,24 @@ func (f *webDavFile) Patch(patches []webdav.Proppatch) ([]webdav.Propstat, error } return resp, nil } + +type webDavDirLister struct { + vfs.DirLister + fs vfs.Fs + virtualDirPath string + fsDirPath string +} + +func (l *webDavDirLister) Next(limit int) ([]os.FileInfo, error) { + files, err := l.DirLister.Next(limit) + for idx := range files { + info := files[idx] + files[idx] = &webDavFileInfo{ + FileInfo: info, + Fs: l.fs, + virtualPath: path.Join(l.virtualDirPath, info.Name()), + fsPath: l.fs.Join(l.fsDirPath, info.Name()), + } + } + return files, err +} diff --git a/internal/webdavd/internal_test.go b/internal/webdavd/internal_test.go index 9997d644..0366c7d2 100644 --- a/internal/webdavd/internal_test.go +++ b/internal/webdavd/internal_test.go @@ -692,6 +692,8 @@ func TestContentType(t *testing.T) { assert.Equal(t, "application/custom-mime", ctype) } _, err = davFile.Readdir(-1) + assert.ErrorIs(t, err, webdav.ErrNotImplemented) + _, err = davFile.ReadDir() assert.Error(t, err) err = davFile.Close() assert.NoError(t, err)