webdav: performance improvements and bug fixes

we need my custom golang/x/net/webdav fork for now

https://github.com/drakkan/net/tree/sftpgo
This commit is contained in:
Nicola Murino 2020-11-04 19:11:40 +01:00
parent 442efa0607
commit 0a14297b48
No known key found for this signature in database
GPG key ID: 2F1FB59433D5A8CB
22 changed files with 448 additions and 202 deletions

View file

@ -103,10 +103,16 @@ func init() {
MaxAge: 0, MaxAge: 0,
}, },
Cache: webdavd.Cache{ Cache: webdavd.Cache{
Users: webdavd.UsersCacheConfig{
Enabled: true, Enabled: true,
ExpirationTime: 0, ExpirationTime: 0,
MaxSize: 50, MaxSize: 50,
}, },
MimeTypes: webdavd.MimeCacheConfig{
Enabled: true,
MaxSize: 1000,
},
},
}, },
ProviderConf: dataprovider.Config{ ProviderConf: dataprovider.Config{
Driver: "sqlite", Driver: "sqlite",
@ -393,9 +399,11 @@ func setViperDefaults() {
viper.SetDefault("webdavd.cors.exposed_headers", globalConf.WebDAVD.Cors.ExposedHeaders) viper.SetDefault("webdavd.cors.exposed_headers", globalConf.WebDAVD.Cors.ExposedHeaders)
viper.SetDefault("webdavd.cors.allow_credentials", globalConf.WebDAVD.Cors.AllowCredentials) viper.SetDefault("webdavd.cors.allow_credentials", globalConf.WebDAVD.Cors.AllowCredentials)
viper.SetDefault("webdavd.cors.max_age", globalConf.WebDAVD.Cors.MaxAge) viper.SetDefault("webdavd.cors.max_age", globalConf.WebDAVD.Cors.MaxAge)
viper.SetDefault("webdavd.cache.enabled", globalConf.WebDAVD.Cache.Enabled) viper.SetDefault("webdavd.cache.users.enabled", globalConf.WebDAVD.Cache.Users.Enabled)
viper.SetDefault("webdavd.cache.expiration_time", globalConf.WebDAVD.Cache.ExpirationTime) viper.SetDefault("webdavd.cache.users.expiration_time", globalConf.WebDAVD.Cache.Users.ExpirationTime)
viper.SetDefault("webdavd.cache.max_size", globalConf.WebDAVD.Cache.MaxSize) viper.SetDefault("webdavd.cache.users.max_size", globalConf.WebDAVD.Cache.Users.MaxSize)
viper.SetDefault("webdavd.cache.mime_types.enabled", globalConf.WebDAVD.Cache.MimeTypes.Enabled)
viper.SetDefault("webdavd.cache.mime_types.max_size", globalConf.WebDAVD.Cache.MimeTypes.MaxSize)
viper.SetDefault("data_provider.driver", globalConf.ProviderConf.Driver) viper.SetDefault("data_provider.driver", globalConf.ProviderConf.Driver)
viper.SetDefault("data_provider.name", globalConf.ProviderConf.Name) viper.SetDefault("data_provider.name", globalConf.ProviderConf.Name)
viper.SetDefault("data_provider.host", globalConf.ProviderConf.Host) viper.SetDefault("data_provider.host", globalConf.ProviderConf.Host)

View file

@ -1993,7 +1993,7 @@ func updateVFoldersQuotaAfterRestore(foldersToScan []string) {
} }
// CacheWebDAVUser add a user to the WebDAV cache // CacheWebDAVUser add a user to the WebDAV cache
func CacheWebDAVUser(cachedUser CachedUser, maxSize int) { func CacheWebDAVUser(cachedUser *CachedUser, maxSize int) {
if maxSize > 0 { if maxSize > 0 {
var cacheSize int var cacheSize int
var userToRemove string var userToRemove string
@ -2003,10 +2003,10 @@ func CacheWebDAVUser(cachedUser CachedUser, maxSize int) {
cacheSize++ cacheSize++
if len(userToRemove) == 0 { if len(userToRemove) == 0 {
userToRemove = k.(string) userToRemove = k.(string)
expirationTime = v.(CachedUser).Expiration expirationTime = v.(*CachedUser).Expiration
return true return true
} }
expireTime := v.(CachedUser).Expiration expireTime := v.(*CachedUser).Expiration
if !expireTime.IsZero() && expireTime.Before(expirationTime) { if !expireTime.IsZero() && expireTime.Before(expirationTime) {
userToRemove = k.(string) userToRemove = k.(string)
expirationTime = expireTime expirationTime = expireTime

View file

@ -12,6 +12,8 @@ import (
"strings" "strings"
"time" "time"
"golang.org/x/net/webdav"
"github.com/drakkan/sftpgo/logger" "github.com/drakkan/sftpgo/logger"
"github.com/drakkan/sftpgo/utils" "github.com/drakkan/sftpgo/utils"
"github.com/drakkan/sftpgo/vfs" "github.com/drakkan/sftpgo/vfs"
@ -65,10 +67,11 @@ type CachedUser struct {
User User User User
Expiration time.Time Expiration time.Time
Password string Password string
LockSystem webdav.LockSystem
} }
// IsExpired returns true if the cached user is expired // IsExpired returns true if the cached user is expired
func (c CachedUser) IsExpired() bool { func (c *CachedUser) IsExpired() bool {
if c.Expiration.IsZero() { if c.Expiration.IsZero() {
return false return false
} }

View file

@ -74,6 +74,14 @@ These properties are stored inside the data provider.
If you want to use your existing accounts, you have these options: If you want to use your existing accounts, you have these options:
- If your accounts are already stored inside a supported database, you can create a database view. Since a view is read only, you have to disable user management and quota tracking so SFTPGo will never try to write to the view
- you can import your users inside SFTPGo. Take a look at [sftpgo_api_cli](../examples/rest-api-cli#convert-users-from-other-stores "SFTPGo API CLI example"), it can convert and import users from Linux system users and Pure-FTPd/ProFTPD virtual users - you can import your users inside SFTPGo. Take a look at [sftpgo_api_cli](../examples/rest-api-cli#convert-users-from-other-stores "SFTPGo API CLI example"), it can convert and import users from Linux system users and Pure-FTPd/ProFTPD virtual users
- you can use an external authentication program - you can use an external authentication program
Please take a look at the [OpenAPI schema](../httpd/schema/openapi.yaml) for the exact definitions of user and folder fields.
If you need an example you can export a dump using the REST API CLI client or by invoking the `dumpdata` endpoint directly, for example:
```shell
curl "http://127.0.0.1:8080/api/v1/dumpdata?output_file=dump.json&indent=1"
```
the dump is a JSON with users and folder.

View file

@ -7,19 +7,23 @@ Each user has his own path like `http/s://<SFTPGo ip>:<WevDAVPORT>/<username>` a
WebDAV is quite a different protocol than SCP/FTP, there is no session concept, each command is a separate HTTP request and must be authenticated, performance can be greatly improved enabling caching for the authenticated users (it is enabled by default). This way SFTPGo don't need to do a dataprovider query and a password check for each request. WebDAV is quite a different protocol than SCP/FTP, there is no session concept, each command is a separate HTTP request and must be authenticated, performance can be greatly improved enabling caching for the authenticated users (it is enabled by default). This way SFTPGo don't need to do a dataprovider query and a password check for each request.
If you enable quota support a dataprovider query is required, to update the user quota, after each file upload. If you enable quota support a dataprovider query is required, to update the user quota, after each file upload.
The caching configuration allows to set: The user caching configuration allows to set:
- `expiration_time` in minutes. If a user is cached for more than the specified minutes it will be removed from the cache and a new dataprovider query will be performed. Please note that the `last_login` field will not be updated and `external_auth_hook`, `pre_login_hook` and `check_password_hook` will not be executed if the user is obtained from the cache. - `expiration_time` in minutes. If a user is cached for more than the specified minutes it will be removed from the cache and a new dataprovider query will be performed. Please note that the `last_login` field will not be updated and `external_auth_hook`, `pre_login_hook` and `check_password_hook` will not be executed if the user is obtained from the cache.
- `max_size`. Maximum number of users to cache. When this limit is reached the user with the oldest expiration date will be removed from the cache. 0 means no limit however the cache size cannot exceed the number of users so if you have a small number of users you can leave this setting to 0. - `max_size`. Maximum number of users to cache. When this limit is reached the user with the oldest expiration date will be removed from the cache. 0 means no limit however the cache size cannot exceed the number of users so if you have a small number of users you can leave this setting to 0.
Users are automatically removed from the cache after an update/delete. Users are automatically removed from the cache after an update/delete.
WebDAV protocol requires the MIME type for each file. SFTPGo will first try to guess the MIME type by extension. If this fails it will send a `HEAD` request for Cloud backends and, as last resort, it will try to guess the MIME type reading the first 512 bytes of the file. This may slow down the directory listing, especially for Cloud based backends, if you have directories containing many files with unregistered extensions. To mitigate this problem, you can enable caching of MIME types so that the MIME type detection is done only once.
The MIME types caching configurations allows to set the maximum number of MIME types to cache. Once the cache reaches the configured maximum size no new MIME types will be added. The MIME types cache is a non-persistent in-memory cache. If you need a persistent cache add your MIME types to `/etc/mime.types` on Linux or inside the registry on Windows.
WebDAV should work as expected for most use cases but there are some minor issues and some missing features. WebDAV should work as expected for most use cases but there are some minor issues and some missing features.
Know issues: Know issues:
- removing a directory tree on Cloud Storage backends could generate a `not found` error when removing the last (virtual) directory. This happens if the client cycles the directories tree itself and removes files and directories one by one instead of issuing a single remove command - removing a directory tree on Cloud Storage backends could generate a `not found` error when removing the last (virtual) directory. This happens if the client cycles the directories tree itself and removes files and directories one by one instead of issuing a single remove command
- the used [WebDAV library](https://pkg.go.dev/golang.org/x/net/webdav?tab=doc) asks to open a file to execute a `stat` and sometimes reads some bytes to find the content type. We are unable to distinguish a `stat` from a `download` for now, so to be able to properly list a directory you need to grant both `list` and `download` permissions - the used [WebDAV library](https://pkg.go.dev/golang.org/x/net/webdav?tab=doc) asks to open a file to execute a `stat` and sometimes reads some bytes to find the content type. Stat calls are executed before and after a download too, so to be able to properly list a directory you need to grant both `list` and `download` permissions and to be able to upload files you need to gran both `list` and `upload` permissions
- the used `WebDAV library` not always returns a proper error code/message, most of the times it simply returns `Method not Allowed`. I'll try to improve the library error codes in the future - the used `WebDAV library` not always returns a proper error code/message, most of the times it simply returns `Method not Allowed`. I'll try to improve the library error codes in the future
- if an object within a directory cannot be accessed, for example due to OS permissions issues or because is a missing mapped path for a virtual folder, the directory listing will fail. In SFTP/FTP the directory listing will succeed and you'll only get an error if you try to access to the problematic file/directory - if an object within a directory cannot be accessed, for example due to OS permissions issues or because is a missing mapped path for a virtual folder, the directory listing will fail. In SFTP/FTP the directory listing will succeed and you'll only get an error if you try to access to the problematic file/directory

2
go.mod
View file

@ -57,5 +57,5 @@ require (
replace ( replace (
github.com/jlaffaye/ftp => github.com/drakkan/ftp v0.0.0-20200730125632-b21eac28818c github.com/jlaffaye/ftp => github.com/drakkan/ftp v0.0.0-20200730125632-b21eac28818c
golang.org/x/crypto => github.com/drakkan/crypto v0.0.0-20201017144935-4e8324213ac3 golang.org/x/crypto => github.com/drakkan/crypto v0.0.0-20201017144935-4e8324213ac3
golang.org/x/net => github.com/drakkan/net v0.0.0-20201101072345-49fbbaa64b66 golang.org/x/net => github.com/drakkan/net v0.0.0-20201104142514-34ad2afe5beb
) )

6
go.sum
View file

@ -121,6 +121,12 @@ github.com/drakkan/ftp v0.0.0-20200730125632-b21eac28818c h1:QSXIWohSNn0negBVSKE
github.com/drakkan/ftp v0.0.0-20200730125632-b21eac28818c/go.mod h1:2lmrmq866uF2tnje75wQHzmPXhmSWUt7Gyx2vgK1RCU= github.com/drakkan/ftp v0.0.0-20200730125632-b21eac28818c/go.mod h1:2lmrmq866uF2tnje75wQHzmPXhmSWUt7Gyx2vgK1RCU=
github.com/drakkan/net v0.0.0-20201101072345-49fbbaa64b66 h1:Y92YgfaycEmjy9L6CY633pCrxGtAlV3wh5n4vS7U+os= github.com/drakkan/net v0.0.0-20201101072345-49fbbaa64b66 h1:Y92YgfaycEmjy9L6CY633pCrxGtAlV3wh5n4vS7U+os=
github.com/drakkan/net v0.0.0-20201101072345-49fbbaa64b66/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= github.com/drakkan/net v0.0.0-20201101072345-49fbbaa64b66/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
github.com/drakkan/net v0.0.0-20201104095909-9ef94e8aecdc h1:NWOlujrNWzUkQLn+ESgRw/w1Kr7piY2XMn7J7CylHA0=
github.com/drakkan/net v0.0.0-20201104095909-9ef94e8aecdc/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
github.com/drakkan/net v0.0.0-20201104141241-5c1fd0e3eb3e h1:Ke7I1E2awH+aj9S3xSiV7NmEpnjJ/M25+juJl0P8OJo=
github.com/drakkan/net v0.0.0-20201104141241-5c1fd0e3eb3e/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
github.com/drakkan/net v0.0.0-20201104142514-34ad2afe5beb h1:NgZ7GvppCYwS8iG+zcuQvzVCAvTQLtxVe7PSWxJtxhI=
github.com/drakkan/net v0.0.0-20201104142514-34ad2afe5beb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs=
github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU=

View file

@ -1,6 +1,6 @@
#!/bin/bash #!/bin/bash
NFPM_VERSION=1.8.0 NFPM_VERSION=1.9.0
if [ -z ${SFTPGO_VERSION} ] if [ -z ${SFTPGO_VERSION} ]
then then
@ -94,7 +94,7 @@ EOF
curl --retry 5 --retry-delay 2 --connect-timeout 10 -L -O \ curl --retry 5 --retry-delay 2 --connect-timeout 10 -L -O \
https://github.com/goreleaser/nfpm/releases/download/v${NFPM_VERSION}/nfpm_${NFPM_VERSION}_Linux_x86_64.tar.gz https://github.com/goreleaser/nfpm/releases/download/v${NFPM_VERSION}/nfpm_${NFPM_VERSION}_Linux_x86_64.tar.gz
tar xvf nfpm_1.8.0_Linux_x86_64.tar.gz nfpm tar xvf nfpm_${NFPM_VERSION}_Linux_x86_64.tar.gz nfpm
chmod 755 nfpm chmod 755 nfpm
mkdir deb mkdir deb
./nfpm -f nfpm.yaml pkg -p deb -t deb ./nfpm -f nfpm.yaml pkg -p deb -t deb

View file

@ -65,8 +65,8 @@ func (s *Service) Start() error {
} }
} }
logger.Info(logSender, "", "starting SFTPGo %v, config dir: %v, config file: %v, log max size: %v log max backups: %v "+ logger.Info(logSender, "", "starting SFTPGo %v, config dir: %v, config file: %v, log max size: %v log max backups: %v "+
"log max age: %v log verbose: %v, log compress: %v, profile: %v", version.GetAsString(), s.ConfigDir, s.ConfigFile, "log max age: %v log verbose: %v, log compress: %v, profile: %v load data from: %#v", version.GetAsString(), s.ConfigDir, s.ConfigFile,
s.LogMaxSize, s.LogMaxBackups, s.LogMaxAge, s.LogVerbose, s.LogCompress, s.Profiler) s.LogMaxSize, s.LogMaxBackups, s.LogMaxAge, s.LogVerbose, s.LogCompress, s.Profiler, s.LoadDataFrom)
// in portable mode we don't read configuration from file // in portable mode we don't read configuration from file
if s.PortableMode != 1 { if s.PortableMode != 1 {
err := config.LoadConfig(s.ConfigDir, s.ConfigFile) err := config.LoadConfig(s.ConfigDir, s.ConfigFile)
@ -221,8 +221,10 @@ func (s *Service) loadInitialData() error {
if err != nil { if err != nil {
return fmt.Errorf("unable to restore users from file %#v: %v", s.LoadDataFrom, err) return fmt.Errorf("unable to restore users from file %#v: %v", s.LoadDataFrom, err)
} }
logger.Info(logSender, "", "data loaded from file %#v", s.LoadDataFrom) logger.Info(logSender, "", "data loaded from file %#v mode: %v, quota scan %v", s.LoadDataFrom,
logger.InfoToConsole("data loaded from file %#v", s.LoadDataFrom) s.LoadDataMode, s.LoadDataQuotaScan)
logger.InfoToConsole("data loaded from file %#v mode: %v, quota scan %v", s.LoadDataFrom,
s.LoadDataMode, s.LoadDataQuotaScan)
if s.LoadDataClean { if s.LoadDataClean {
err = os.Remove(s.LoadDataFrom) err = os.Remove(s.LoadDataFrom)
if err == nil { if err == nil {

View file

@ -62,9 +62,15 @@
"max_age": 0 "max_age": 0
}, },
"cache": { "cache": {
"users": {
"enabled": true, "enabled": true,
"expiration_time": 0, "expiration_time": 0,
"max_size": 50 "max_size": 50
},
"mime_types": {
"enabled": true,
"max_size": 1000
}
} }
}, },
"data_provider": { "data_provider": {

View file

@ -681,7 +681,7 @@ func (fs *AzureBlobFs) headObject(name string) (*azblob.BlobGetPropertiesRespons
return response, err return response, err
} }
// GetMimeType implements MimeTyper interface // GetMimeType returns the content type
func (fs *AzureBlobFs) GetMimeType(name string) (string, error) { func (fs *AzureBlobFs) GetMimeType(name string) (string, error) {
response, err := fs.headObject(name) response, err := fs.headObject(name)
if err != nil { if err != nil {

View file

@ -709,7 +709,7 @@ func (fs *GCSFs) headObject(name string) (*storage.ObjectAttrs, error) {
return attrs, err return attrs, err
} }
// GetMimeType implements MimeTyper interface // GetMimeType returns the content type
func (fs *GCSFs) GetMimeType(name string) (string, error) { func (fs *GCSFs) GetMimeType(name string) (string, error) {
attrs, err := fs.headObject(name) attrs, err := fs.headObject(name)
if err != nil { if err != nil {

View file

@ -2,6 +2,8 @@ package vfs
import ( import (
"fmt" "fmt"
"io"
"net/http"
"os" "os"
"path" "path"
"path/filepath" "path/filepath"
@ -445,3 +447,21 @@ func (fs *OsFs) createMissingDirs(filePath string, uid, gid int) error {
} }
return nil return nil
} }
// GetMimeType returns the content type
func (fs *OsFs) GetMimeType(name string) (string, error) {
f, err := os.OpenFile(name, os.O_RDONLY, 0)
if err != nil {
return "", err
}
defer f.Close()
var buf [512]byte
n, err := io.ReadFull(f, buf[:])
if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
return "", err
}
ctype := http.DetectContentType(buf[:n])
// Rewind file.
_, err = f.Seek(0, io.SeekStart)
return ctype, err
}

View file

@ -674,7 +674,7 @@ func (fs *S3Fs) headObject(name string) (*s3.HeadObjectOutput, error) {
return obj, err return obj, err
} }
// GetMimeType implements MimeTyper interface // GetMimeType returns the content type
func (fs *S3Fs) GetMimeType(name string) (string, error) { func (fs *S3Fs) GetMimeType(name string) (string, error) {
obj, err := fs.headObject(name) obj, err := fs.headObject(name)
if err != nil { if err != nil {

View file

@ -53,10 +53,6 @@ type Fs interface {
Walk(root string, walkFn filepath.WalkFunc) error Walk(root string, walkFn filepath.WalkFunc) error
Join(elem ...string) string Join(elem ...string) string
HasVirtualFolders() bool HasVirtualFolders() bool
}
// MimeTyper defines an optional interface to get the content type
type MimeTyper interface {
GetMimeType(name string) (string, error) GetMimeType(name string) (string, error)
} }

View file

@ -14,6 +14,8 @@ import (
"golang.org/x/net/webdav" "golang.org/x/net/webdav"
"github.com/drakkan/sftpgo/common" "github.com/drakkan/sftpgo/common"
"github.com/drakkan/sftpgo/dataprovider"
"github.com/drakkan/sftpgo/logger"
"github.com/drakkan/sftpgo/vfs" "github.com/drakkan/sftpgo/vfs"
) )
@ -23,13 +25,13 @@ type webDavFile struct {
*common.BaseTransfer *common.BaseTransfer
writer io.WriteCloser writer io.WriteCloser
reader io.ReadCloser reader io.ReadCloser
isFinished bool
startOffset int64
info os.FileInfo info os.FileInfo
startOffset int64
isFinished bool
readTryed int32
} }
func newWebDavFile(baseTransfer *common.BaseTransfer, pipeWriter *vfs.PipeWriter, pipeReader *pipeat.PipeReaderAt, func newWebDavFile(baseTransfer *common.BaseTransfer, pipeWriter *vfs.PipeWriter, pipeReader *pipeat.PipeReaderAt) *webDavFile {
info os.FileInfo) *webDavFile {
var writer io.WriteCloser var writer io.WriteCloser
var reader io.ReadCloser var reader io.ReadCloser
if baseTransfer.File != nil { if baseTransfer.File != nil {
@ -46,65 +48,86 @@ func newWebDavFile(baseTransfer *common.BaseTransfer, pipeWriter *vfs.PipeWriter
reader: reader, reader: reader,
isFinished: false, isFinished: false,
startOffset: 0, startOffset: 0,
info: info, info: nil,
readTryed: 0,
} }
} }
type webDavFileInfo struct { type webDavFileInfo struct {
os.FileInfo os.FileInfo
file *webDavFile Fs vfs.Fs
virtualPath string
fsPath string
} }
// ContentType implements webdav.ContentTyper interface // ContentType implements webdav.ContentTyper interface
func (fi webDavFileInfo) ContentType(ctx context.Context) (string, error) { func (fi *webDavFileInfo) ContentType(ctx context.Context) (string, error) {
contentType := mime.TypeByExtension(path.Ext(fi.file.GetVirtualPath())) extension := path.Ext(fi.virtualPath)
contentType := mime.TypeByExtension(extension)
if contentType != "" { if contentType != "" {
return contentType, nil return contentType, nil
} }
if c, ok := fi.file.Fs.(vfs.MimeTyper); ok { contentType = mimeTypeCache.getMimeFromCache(extension)
contentType, err := c.GetMimeType(fi.file.GetFsPath()) if contentType != "" {
return contentType, nil
}
contentType, err := fi.Fs.GetMimeType(fi.fsPath)
mimeTypeCache.addMimeToCache(extension, contentType)
if contentType != "" {
return contentType, err return contentType, err
} }
return contentType, webdav.ErrNotImplemented return "", webdav.ErrNotImplemented
} }
// Readdir reads directory entries from the handle // Readdir reads directory entries from the handle
func (f *webDavFile) Readdir(count int) ([]os.FileInfo, error) { func (f *webDavFile) Readdir(count int) ([]os.FileInfo, error) {
if f.isDir() { if !f.Connection.User.HasPerm(dataprovider.PermListItems, f.GetVirtualPath()) {
return f.Connection.ListDir(f.GetFsPath(), f.GetVirtualPath()) return nil, f.Connection.GetPermissionDeniedError()
} }
return nil, errors.New("we can only list directories contents, this is not a directory") fileInfos, err := f.Connection.ListDir(f.GetFsPath(), f.GetVirtualPath())
if err != nil {
return nil, err
}
result := make([]os.FileInfo, 0, len(fileInfos))
for _, fileInfo := range fileInfos {
result = append(result, &webDavFileInfo{
FileInfo: fileInfo,
Fs: f.Fs,
virtualPath: path.Join(f.GetVirtualPath(), fileInfo.Name()),
fsPath: f.Fs.Join(f.GetFsPath(), fileInfo.Name()),
})
}
return result, nil
} }
// Stat the handle // Stat the handle
func (f *webDavFile) Stat() (os.FileInfo, error) { func (f *webDavFile) Stat() (os.FileInfo, error) {
if f.info != nil { if f.GetType() == common.TransferDownload && !f.Connection.User.HasPerm(dataprovider.PermListItems, path.Dir(f.GetVirtualPath())) {
fi := webDavFileInfo{ return nil, f.Connection.GetPermissionDeniedError()
FileInfo: f.info,
file: f,
}
return fi, nil
} }
f.Lock() f.Lock()
closed := f.isFinished
errUpload := f.ErrTransfer errUpload := f.ErrTransfer
f.Unlock() f.Unlock()
if f.GetType() == common.TransferUpload && closed && errUpload == nil { if f.GetType() == common.TransferUpload && errUpload == nil {
info := webDavFileInfo{ info := &webDavFileInfo{
FileInfo: vfs.NewFileInfo(f.GetFsPath(), false, atomic.LoadInt64(&f.BytesReceived), time.Now(), false), FileInfo: vfs.NewFileInfo(f.GetFsPath(), false, atomic.LoadInt64(&f.BytesReceived), time.Now(), false),
file: f, Fs: f.Fs,
virtualPath: f.GetVirtualPath(),
fsPath: f.GetFsPath(),
} }
return info, nil return info, nil
} }
info, err := f.Fs.Stat(f.GetFsPath()) info, err := f.Fs.Stat(f.GetFsPath())
if err != nil { if err != nil {
return info, err return nil, err
} }
fi := webDavFileInfo{ fi := &webDavFileInfo{
FileInfo: info, FileInfo: info,
file: f, Fs: f.Fs,
virtualPath: f.GetVirtualPath(),
fsPath: f.GetFsPath(),
} }
return fi, err return fi, nil
} }
// Read reads the contents to downloads. // Read reads the contents to downloads.
@ -112,6 +135,18 @@ func (f *webDavFile) Read(p []byte) (n int, err error) {
if atomic.LoadInt32(&f.AbortTransfer) == 1 { if atomic.LoadInt32(&f.AbortTransfer) == 1 {
return 0, errTransferAborted return 0, errTransferAborted
} }
if atomic.LoadInt32(&f.readTryed) == 0 {
atomic.StoreInt32(&f.readTryed, 1)
if !f.Connection.User.HasPerm(dataprovider.PermDownload, path.Dir(f.GetVirtualPath())) {
return 0, f.Connection.GetPermissionDeniedError()
}
if !f.Connection.User.IsFileAllowed(f.GetVirtualPath()) {
f.Connection.Log(logger.LevelWarn, "reading file %#v is not allowed", f.GetVirtualPath())
return 0, f.Connection.GetPermissionDeniedError()
}
}
f.Connection.UpdateLastActivity() f.Connection.UpdateLastActivity()
@ -167,6 +202,18 @@ func (f *webDavFile) Write(p []byte) (n int, err error) {
return return
} }
func (f *webDavFile) updateStatInfo() error {
if f.info != nil {
return nil
}
info, err := f.Fs.Stat(f.GetFsPath())
if err != nil {
return err
}
f.info = info
return nil
}
// Seek sets the offset for the next Read or Write on the writer to offset, // Seek sets the offset for the next Read or Write on the writer to offset,
// interpreted according to whence: 0 means relative to the origin of the file, // interpreted according to whence: 0 means relative to the origin of the file,
// 1 means relative to the current offset, and 2 means relative to the end. // 1 means relative to the current offset, and 2 means relative to the end.
@ -185,7 +232,10 @@ func (f *webDavFile) Seek(offset int64, whence int) (int64, error) {
if offset == 0 && readOffset == 0 { if offset == 0 && readOffset == 0 {
if whence == io.SeekStart { if whence == io.SeekStart {
return 0, nil return 0, nil
} else if whence == io.SeekEnd && f.info != nil { } else if whence == io.SeekEnd {
if err := f.updateStatInfo(); err != nil {
return 0, err
}
return f.info.Size(), nil return f.info.Size(), nil
} }
} }
@ -204,13 +254,11 @@ func (f *webDavFile) Seek(offset int64, whence int) (int64, error) {
case io.SeekCurrent: case io.SeekCurrent:
startByte = readOffset + offset startByte = readOffset + offset
case io.SeekEnd: case io.SeekEnd:
if f.info != nil { if err := f.updateStatInfo(); err != nil {
startByte = f.info.Size() - offset
} else {
err := errors.New("unable to get file size, seek from end not possible")
f.TransferError(err) f.TransferError(err)
return 0, err return 0, err
} }
startByte = f.info.Size() - offset
} }
_, r, cancelFn, err := f.Fs.Open(f.GetFsPath(), startByte) _, r, cancelFn, err := f.Fs.Open(f.GetFsPath(), startByte)
@ -274,16 +322,9 @@ func (f *webDavFile) setFinished() error {
return nil return nil
} }
func (f *webDavFile) isDir() bool {
if f.info == nil {
return false
}
return f.info.IsDir()
}
func (f *webDavFile) isTransfer() bool { func (f *webDavFile) isTransfer() bool {
if f.GetType() == common.TransferDownload { if f.GetType() == common.TransferDownload {
return (f.reader != nil) return atomic.LoadInt32(&f.readTryed) > 0
} }
return true return true
} }

View file

@ -144,52 +144,20 @@ func (c *Connection) OpenFile(ctx context.Context, name string, flag int, perm o
return nil, c.GetFsError(err) return nil, c.GetFsError(err)
} }
if flag == os.O_RDONLY { if flag == os.O_RDONLY {
// Download, Stat or Readdir // Download, Stat, Readdir or simply open/close
fi, err := c.Fs.Lstat(p) return c.getFile(p, name)
if err != nil {
return nil, c.GetFsError(err)
}
return c.getFile(p, name, fi)
} }
return c.putFile(p, name) return c.putFile(p, name)
} }
func (c *Connection) getFile(fsPath, virtualPath string, info os.FileInfo) (webdav.File, error) { func (c *Connection) getFile(fsPath, virtualPath string) (webdav.File, error) {
var err error var err error
if info.IsDir() {
if !c.User.HasPerm(dataprovider.PermListItems, virtualPath) {
return nil, c.GetPermissionDeniedError()
}
var file *os.File
if vfs.IsLocalOsFs(c.Fs) {
file, _, _, err = c.Fs.Open(fsPath, 0)
if err != nil {
c.Log(logger.LevelWarn, "could not open directory %#v for reading: %+v", fsPath, err)
return nil, c.GetFsError(err)
}
}
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, nil, fsPath, virtualPath, common.TransferDownload,
0, 0, 0, false, c.Fs)
return newWebDavFile(baseTransfer, nil, nil, info), nil
}
// we don't know if the file will be downloaded or opened for get properties so we check both permissions
if !c.User.HasPerms([]string{dataprovider.PermDownload, dataprovider.PermListItems}, path.Dir(virtualPath)) {
return nil, c.GetPermissionDeniedError()
}
if !c.User.IsFileAllowed(virtualPath) {
c.Log(logger.LevelWarn, "reading file %#v is not allowed", virtualPath)
return nil, c.GetPermissionDeniedError()
}
var file *os.File var file *os.File
var r *pipeat.PipeReaderAt var r *pipeat.PipeReaderAt
var cancelFn func() var cancelFn func()
// for cloud fs we open the file when we receive the first read to avoid to download the first part of // for cloud fs we open the file when we receive the first read to avoid to download the first part of
// the file if it was opened to get stats and not for a real download // the file if it was opened only to do a stat or a readdir and so it ins't a download
if vfs.IsLocalOsFs(c.Fs) { if vfs.IsLocalOsFs(c.Fs) {
file, r, cancelFn, err = c.Fs.Open(fsPath, 0) file, r, cancelFn, err = c.Fs.Open(fsPath, 0)
if err != nil { if err != nil {
@ -201,7 +169,7 @@ func (c *Connection) getFile(fsPath, virtualPath string, info os.FileInfo) (webd
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, fsPath, virtualPath, common.TransferDownload, baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, fsPath, virtualPath, common.TransferDownload,
0, 0, 0, false, c.Fs) 0, 0, 0, false, c.Fs)
return newWebDavFile(baseTransfer, nil, r, info), nil return newWebDavFile(baseTransfer, nil, r), nil
} }
func (c *Connection) putFile(fsPath, virtualPath string) (webdav.File, error) { func (c *Connection) putFile(fsPath, virtualPath string) (webdav.File, error) {
@ -261,7 +229,7 @@ func (c *Connection) handleUploadToNewFile(resolvedPath, filePath, requestPath s
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, requestPath, baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, requestPath,
common.TransferUpload, 0, 0, maxWriteSize, true, c.Fs) common.TransferUpload, 0, 0, maxWriteSize, true, c.Fs)
return newWebDavFile(baseTransfer, w, nil, nil), nil return newWebDavFile(baseTransfer, w, nil), nil
} }
func (c *Connection) handleUploadToExistingFile(resolvedPath, filePath string, fileSize int64, func (c *Connection) handleUploadToExistingFile(resolvedPath, filePath string, fileSize int64,
@ -311,7 +279,7 @@ func (c *Connection) handleUploadToExistingFile(resolvedPath, filePath string, f
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, requestPath, baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, requestPath,
common.TransferUpload, 0, initialSize, maxWriteSize, false, c.Fs) common.TransferUpload, 0, initialSize, maxWriteSize, false, c.Fs)
return newWebDavFile(baseTransfer, w, nil, nil), nil return newWebDavFile(baseTransfer, w, nil), nil
} }
type objectMapping struct { type objectMapping struct {

View file

@ -18,6 +18,7 @@ import (
"github.com/eikenb/pipeat" "github.com/eikenb/pipeat"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"golang.org/x/net/webdav"
"github.com/drakkan/sftpgo/common" "github.com/drakkan/sftpgo/common"
"github.com/drakkan/sftpgo/dataprovider" "github.com/drakkan/sftpgo/dataprovider"
@ -89,9 +90,9 @@ func (fs MockOsFs) Walk(root string, walkFn filepath.WalkFunc) error {
return fs.err return fs.err
} }
// GetMimeType implements vfs.MimeTyper // GetMimeType returns the content type
func (fs MockOsFs) GetMimeType(name string) (string, error) { func (fs MockOsFs) GetMimeType(name string) (string, error) {
return "application/octet-stream", nil return "application/custom-mime", nil
} }
func newMockOsFs(err error, atomicUpload bool, connectionID, rootDir string) vfs.Fs { func newMockOsFs(err error, atomicUpload bool, connectionID, rootDir string) vfs.Fs {
@ -319,13 +320,11 @@ func TestFileAccessErrors(t *testing.T) {
if assert.Error(t, err) { if assert.Error(t, err) {
assert.EqualError(t, err, os.ErrNotExist.Error()) assert.EqualError(t, err, os.ErrNotExist.Error())
} }
info := vfs.NewFileInfo(missingPath, true, 0, time.Now(), false) _, err = connection.getFile(fsMissingPath, missingPath)
_, err = connection.getFile(fsMissingPath, missingPath, info)
if assert.Error(t, err) { if assert.Error(t, err) {
assert.EqualError(t, err, os.ErrNotExist.Error()) assert.EqualError(t, err, os.ErrNotExist.Error())
} }
info = vfs.NewFileInfo(missingPath, false, 123, time.Now(), false) _, err = connection.getFile(fsMissingPath, missingPath)
_, err = connection.getFile(fsMissingPath, missingPath, info)
if assert.Error(t, err) { if assert.Error(t, err) {
assert.EqualError(t, err, os.ErrNotExist.Error()) assert.EqualError(t, err, os.ErrNotExist.Error())
} }
@ -434,20 +433,34 @@ func TestContentType(t *testing.T) {
fs = newMockOsFs(nil, false, fs.ConnectionID(), user.GetHomeDir()) fs = newMockOsFs(nil, false, fs.ConnectionID(), user.GetHomeDir())
err := ioutil.WriteFile(testFilePath, []byte(""), os.ModePerm) err := ioutil.WriteFile(testFilePath, []byte(""), os.ModePerm)
assert.NoError(t, err) assert.NoError(t, err)
fi, err := os.Stat(testFilePath) davFile := newWebDavFile(baseTransfer, nil, nil)
assert.NoError(t, err)
davFile := newWebDavFile(baseTransfer, nil, nil, fi)
davFile.Fs = fs davFile.Fs = fs
fi, err = davFile.Stat() fi, err := davFile.Stat()
if assert.NoError(t, err) { if assert.NoError(t, err) {
ctype, err := fi.(webDavFileInfo).ContentType(ctx) ctype, err := fi.(*webDavFileInfo).ContentType(ctx)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, "application/octet-stream", ctype) assert.Equal(t, "application/custom-mime", ctype)
} }
_, err = davFile.Readdir(-1) _, err = davFile.Readdir(-1)
assert.Error(t, err) assert.Error(t, err)
err = davFile.Close() err = davFile.Close()
assert.NoError(t, err) assert.NoError(t, err)
davFile = newWebDavFile(baseTransfer, nil, nil)
davFile.Fs = vfs.NewOsFs("id", user.HomeDir, nil)
fi, err = davFile.Stat()
if assert.NoError(t, err) {
ctype, err := fi.(*webDavFileInfo).ContentType(ctx)
assert.NoError(t, err)
assert.Equal(t, "text/plain; charset=utf-8", ctype)
}
err = davFile.Close()
assert.NoError(t, err)
fi.(*webDavFileInfo).fsPath = "missing"
_, err = fi.(*webDavFileInfo).ContentType(ctx)
assert.EqualError(t, err, webdav.ErrNotImplemented.Error())
err = os.Remove(testFilePath) err = os.Remove(testFilePath)
assert.NoError(t, err) assert.NoError(t, err)
} }
@ -465,17 +478,16 @@ func TestTransferReadWriteErrors(t *testing.T) {
testFilePath := filepath.Join(user.HomeDir, testFile) testFilePath := filepath.Join(user.HomeDir, testFile)
baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFile, baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFile,
common.TransferUpload, 0, 0, 0, false, fs) common.TransferUpload, 0, 0, 0, false, fs)
davFile := newWebDavFile(baseTransfer, nil, nil, nil) davFile := newWebDavFile(baseTransfer, nil, nil)
assert.False(t, davFile.isDir())
p := make([]byte, 1) p := make([]byte, 1)
_, err := davFile.Read(p) _, err := davFile.Read(p)
assert.EqualError(t, err, common.ErrOpUnsupported.Error()) assert.EqualError(t, err, common.ErrOpUnsupported.Error())
r, w, err := pipeat.Pipe() r, w, err := pipeat.Pipe()
assert.NoError(t, err) assert.NoError(t, err)
davFile = newWebDavFile(baseTransfer, nil, r, nil) davFile = newWebDavFile(baseTransfer, nil, r)
davFile.Connection.RemoveTransfer(davFile.BaseTransfer) davFile.Connection.RemoveTransfer(davFile.BaseTransfer)
davFile = newWebDavFile(baseTransfer, vfs.NewPipeWriter(w), nil, nil) davFile = newWebDavFile(baseTransfer, vfs.NewPipeWriter(w), nil)
davFile.Connection.RemoveTransfer(davFile.BaseTransfer) davFile.Connection.RemoveTransfer(davFile.BaseTransfer)
err = r.Close() err = r.Close()
assert.NoError(t, err) assert.NoError(t, err)
@ -484,7 +496,7 @@ func TestTransferReadWriteErrors(t *testing.T) {
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFile, baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFile,
common.TransferDownload, 0, 0, 0, false, fs) common.TransferDownload, 0, 0, 0, false, fs)
davFile = newWebDavFile(baseTransfer, nil, nil, nil) davFile = newWebDavFile(baseTransfer, nil, nil)
_, err = davFile.Read(p) _, err = davFile.Read(p)
assert.True(t, os.IsNotExist(err)) assert.True(t, os.IsNotExist(err))
_, err = davFile.Stat() _, err = davFile.Stat()
@ -499,7 +511,7 @@ func TestTransferReadWriteErrors(t *testing.T) {
err = f.Close() err = f.Close()
assert.NoError(t, err) assert.NoError(t, err)
} }
davFile = newWebDavFile(baseTransfer, nil, nil, nil) davFile = newWebDavFile(baseTransfer, nil, nil)
davFile.reader = f davFile.reader = f
err = davFile.Close() err = davFile.Close()
assert.EqualError(t, err, common.ErrGenericFailure.Error()) assert.EqualError(t, err, common.ErrGenericFailure.Error())
@ -514,7 +526,7 @@ func TestTransferReadWriteErrors(t *testing.T) {
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFile, baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFile,
common.TransferDownload, 0, 0, 0, false, fs) common.TransferDownload, 0, 0, 0, false, fs)
davFile = newWebDavFile(baseTransfer, nil, nil, nil) davFile = newWebDavFile(baseTransfer, nil, nil)
davFile.writer = f davFile.writer = f
err = davFile.Close() err = davFile.Close()
assert.EqualError(t, err, common.ErrGenericFailure.Error()) assert.EqualError(t, err, common.ErrGenericFailure.Error())
@ -534,9 +546,10 @@ func TestTransferSeek(t *testing.T) {
BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, user, fs), BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, user, fs),
} }
testFilePath := filepath.Join(user.HomeDir, testFile) testFilePath := filepath.Join(user.HomeDir, testFile)
testFileContents := []byte("content")
baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFile, baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFile,
common.TransferUpload, 0, 0, 0, false, fs) common.TransferUpload, 0, 0, 0, false, fs)
davFile := newWebDavFile(baseTransfer, nil, nil, nil) davFile := newWebDavFile(baseTransfer, nil, nil)
_, err := davFile.Seek(0, io.SeekStart) _, err := davFile.Seek(0, io.SeekStart)
assert.EqualError(t, err, common.ErrOpUnsupported.Error()) assert.EqualError(t, err, common.ErrOpUnsupported.Error())
err = davFile.Close() err = davFile.Close()
@ -544,12 +557,12 @@ func TestTransferSeek(t *testing.T) {
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFile, baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFile,
common.TransferDownload, 0, 0, 0, false, fs) common.TransferDownload, 0, 0, 0, false, fs)
davFile = newWebDavFile(baseTransfer, nil, nil, nil) davFile = newWebDavFile(baseTransfer, nil, nil)
_, err = davFile.Seek(0, io.SeekCurrent) _, err = davFile.Seek(0, io.SeekCurrent)
assert.True(t, os.IsNotExist(err)) assert.True(t, os.IsNotExist(err))
davFile.Connection.RemoveTransfer(davFile.BaseTransfer) davFile.Connection.RemoveTransfer(davFile.BaseTransfer)
err = ioutil.WriteFile(testFilePath, []byte("content"), os.ModePerm) err = ioutil.WriteFile(testFilePath, testFileContents, os.ModePerm)
assert.NoError(t, err) assert.NoError(t, err)
f, err := os.Open(testFilePath) f, err := os.Open(testFilePath)
if assert.NoError(t, err) { if assert.NoError(t, err) {
@ -558,44 +571,55 @@ func TestTransferSeek(t *testing.T) {
} }
baseTransfer = common.NewBaseTransfer(f, connection.BaseConnection, nil, testFilePath, testFile, baseTransfer = common.NewBaseTransfer(f, connection.BaseConnection, nil, testFilePath, testFile,
common.TransferDownload, 0, 0, 0, false, fs) common.TransferDownload, 0, 0, 0, false, fs)
davFile = newWebDavFile(baseTransfer, nil, nil, nil) davFile = newWebDavFile(baseTransfer, nil, nil)
_, err = davFile.Seek(0, io.SeekStart) _, err = davFile.Seek(0, io.SeekStart)
assert.Error(t, err) assert.Error(t, err)
davFile.Connection.RemoveTransfer(davFile.BaseTransfer) davFile.Connection.RemoveTransfer(davFile.BaseTransfer)
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFile, baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFile,
common.TransferDownload, 0, 0, 0, false, fs) common.TransferDownload, 0, 0, 0, false, fs)
davFile = newWebDavFile(baseTransfer, nil, nil, nil) davFile = newWebDavFile(baseTransfer, nil, nil)
davFile.reader = f
res, err := davFile.Seek(0, io.SeekStart) res, err := davFile.Seek(0, io.SeekStart)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, int64(0), res) assert.Equal(t, int64(0), res)
davFile.Connection.RemoveTransfer(davFile.BaseTransfer) davFile.Connection.RemoveTransfer(davFile.BaseTransfer)
info, err := os.Stat(testFilePath) davFile = newWebDavFile(baseTransfer, nil, nil)
assert.NoError(t, err)
davFile = newWebDavFile(baseTransfer, nil, nil, info)
davFile.reader = f
res, err = davFile.Seek(0, io.SeekEnd) res, err = davFile.Seek(0, io.SeekEnd)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, int64(7), res) assert.Equal(t, int64(len(testFileContents)), res)
err = davFile.updateStatInfo()
assert.Nil(t, err)
davFile = newWebDavFile(baseTransfer, nil, nil, info) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath+"1", testFile,
common.TransferDownload, 0, 0, 0, false, fs)
davFile = newWebDavFile(baseTransfer, nil, nil)
_, err = davFile.Seek(0, io.SeekEnd)
assert.True(t, os.IsNotExist(err))
davFile.Connection.RemoveTransfer(davFile.BaseTransfer)
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFile,
common.TransferDownload, 0, 0, 0, false, fs)
davFile = newWebDavFile(baseTransfer, nil, nil)
davFile.reader = f davFile.reader = f
davFile.Fs = newMockOsFs(nil, true, fs.ConnectionID(), user.GetHomeDir()) davFile.Fs = newMockOsFs(nil, true, fs.ConnectionID(), user.GetHomeDir())
res, err = davFile.Seek(2, io.SeekStart) res, err = davFile.Seek(2, io.SeekStart)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, int64(2), res) assert.Equal(t, int64(2), res)
davFile = newWebDavFile(baseTransfer, nil, nil, info) davFile = newWebDavFile(baseTransfer, nil, nil)
davFile.Fs = newMockOsFs(nil, true, fs.ConnectionID(), user.GetHomeDir()) davFile.Fs = newMockOsFs(nil, true, fs.ConnectionID(), user.GetHomeDir())
res, err = davFile.Seek(2, io.SeekEnd) res, err = davFile.Seek(2, io.SeekEnd)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, int64(5), res) assert.Equal(t, int64(5), res)
davFile = newWebDavFile(baseTransfer, nil, nil, nil) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath+"1", testFile,
common.TransferDownload, 0, 0, 0, false, fs)
davFile = newWebDavFile(baseTransfer, nil, nil)
davFile.Fs = newMockOsFs(nil, true, fs.ConnectionID(), user.GetHomeDir())
res, err = davFile.Seek(2, io.SeekEnd) res, err = davFile.Seek(2, io.SeekEnd)
assert.EqualError(t, err, "unable to get file size, seek from end not possible") assert.True(t, os.IsNotExist(err))
assert.Equal(t, int64(0), res) assert.Equal(t, int64(0), res)
assert.Len(t, common.Connections.GetStats(), 0) assert.Len(t, common.Connections.GetStats(), 0)
@ -622,10 +646,12 @@ func TestBasicUsersCache(t *testing.T) {
c := &Configuration{ c := &Configuration{
BindPort: 9000, BindPort: 9000,
Cache: Cache{ Cache: Cache{
Users: UsersCacheConfig{
Enabled: true, Enabled: true,
MaxSize: 50, MaxSize: 50,
ExpirationTime: 1, ExpirationTime: 1,
}, },
},
} }
server, err := newServer(c, configDir) server, err := newServer(c, configDir)
assert.NoError(t, err) assert.NoError(t, err)
@ -633,48 +659,48 @@ func TestBasicUsersCache(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user.Username), nil) req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user.Username), nil)
assert.NoError(t, err) assert.NoError(t, err)
_, _, err = server.authenticate(req) _, _, _, err = server.authenticate(req) //nolint:dogsled
assert.Error(t, err) assert.Error(t, err)
now := time.Now() now := time.Now()
req.SetBasicAuth(username, password) req.SetBasicAuth(username, password)
_, isCached, err := server.authenticate(req) _, isCached, _, err := server.authenticate(req)
assert.NoError(t, err) assert.NoError(t, err)
assert.False(t, isCached) assert.False(t, isCached)
// now the user should be cached // now the user should be cached
var cachedUser dataprovider.CachedUser var cachedUser *dataprovider.CachedUser
result, ok := dataprovider.GetCachedWebDAVUser(username) result, ok := dataprovider.GetCachedWebDAVUser(username)
if assert.True(t, ok) { if assert.True(t, ok) {
cachedUser = result.(dataprovider.CachedUser) cachedUser = result.(*dataprovider.CachedUser)
assert.False(t, cachedUser.IsExpired()) assert.False(t, cachedUser.IsExpired())
assert.True(t, cachedUser.Expiration.After(now.Add(time.Duration(c.Cache.ExpirationTime)*time.Minute))) assert.True(t, cachedUser.Expiration.After(now.Add(time.Duration(c.Cache.Users.ExpirationTime)*time.Minute)))
// authenticate must return the cached user now // authenticate must return the cached user now
authUser, isCached, err := server.authenticate(req) authUser, isCached, _, err := server.authenticate(req)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, isCached) assert.True(t, isCached)
assert.Equal(t, cachedUser.User, authUser) assert.Equal(t, cachedUser.User, authUser)
} }
// a wrong password must fail // a wrong password must fail
req.SetBasicAuth(username, "wrong") req.SetBasicAuth(username, "wrong")
_, _, err = server.authenticate(req) _, _, _, err = server.authenticate(req) //nolint:dogsled
assert.EqualError(t, err, dataprovider.ErrInvalidCredentials.Error()) assert.EqualError(t, err, dataprovider.ErrInvalidCredentials.Error())
req.SetBasicAuth(username, password) req.SetBasicAuth(username, password)
// force cached user expiration // force cached user expiration
cachedUser.Expiration = now cachedUser.Expiration = now
dataprovider.CacheWebDAVUser(cachedUser, c.Cache.MaxSize) dataprovider.CacheWebDAVUser(cachedUser, c.Cache.Users.MaxSize)
result, ok = dataprovider.GetCachedWebDAVUser(username) result, ok = dataprovider.GetCachedWebDAVUser(username)
if assert.True(t, ok) { if assert.True(t, ok) {
cachedUser = result.(dataprovider.CachedUser) cachedUser = result.(*dataprovider.CachedUser)
assert.True(t, cachedUser.IsExpired()) assert.True(t, cachedUser.IsExpired())
} }
// now authenticate should get the user from the data provider and update the cache // now authenticate should get the user from the data provider and update the cache
_, isCached, err = server.authenticate(req) _, isCached, _, err = server.authenticate(req)
assert.NoError(t, err) assert.NoError(t, err)
assert.False(t, isCached) assert.False(t, isCached)
result, ok = dataprovider.GetCachedWebDAVUser(username) result, ok = dataprovider.GetCachedWebDAVUser(username)
if assert.True(t, ok) { if assert.True(t, ok) {
cachedUser = result.(dataprovider.CachedUser) cachedUser = result.(*dataprovider.CachedUser)
assert.False(t, cachedUser.IsExpired()) assert.False(t, cachedUser.IsExpired())
} }
// cache is invalidated after a user modification // cache is invalidated after a user modification
@ -683,7 +709,7 @@ func TestBasicUsersCache(t *testing.T) {
_, ok = dataprovider.GetCachedWebDAVUser(username) _, ok = dataprovider.GetCachedWebDAVUser(username)
assert.False(t, ok) assert.False(t, ok)
_, isCached, err = server.authenticate(req) _, isCached, _, err = server.authenticate(req)
assert.NoError(t, err) assert.NoError(t, err)
assert.False(t, isCached) assert.False(t, isCached)
_, ok = dataprovider.GetCachedWebDAVUser(username) _, ok = dataprovider.GetCachedWebDAVUser(username)
@ -725,10 +751,12 @@ func TestUsersCacheSizeAndExpiration(t *testing.T) {
c := &Configuration{ c := &Configuration{
BindPort: 9000, BindPort: 9000,
Cache: Cache{ Cache: Cache{
Users: UsersCacheConfig{
Enabled: true, Enabled: true,
MaxSize: 3, MaxSize: 3,
ExpirationTime: 1, ExpirationTime: 1,
}, },
},
} }
server, err := newServer(c, configDir) server, err := newServer(c, configDir)
assert.NoError(t, err) assert.NoError(t, err)
@ -736,21 +764,21 @@ func TestUsersCacheSizeAndExpiration(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user1.Username), nil) req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user1.Username), nil)
assert.NoError(t, err) assert.NoError(t, err)
req.SetBasicAuth(user1.Username, password+"1") req.SetBasicAuth(user1.Username, password+"1")
_, isCached, err := server.authenticate(req) _, isCached, _, err := server.authenticate(req)
assert.NoError(t, err) assert.NoError(t, err)
assert.False(t, isCached) assert.False(t, isCached)
req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user2.Username), nil) req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user2.Username), nil)
assert.NoError(t, err) assert.NoError(t, err)
req.SetBasicAuth(user2.Username, password+"2") req.SetBasicAuth(user2.Username, password+"2")
_, isCached, err = server.authenticate(req) _, isCached, _, err = server.authenticate(req)
assert.NoError(t, err) assert.NoError(t, err)
assert.False(t, isCached) assert.False(t, isCached)
req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user3.Username), nil) req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user3.Username), nil)
assert.NoError(t, err) assert.NoError(t, err)
req.SetBasicAuth(user3.Username, password+"3") req.SetBasicAuth(user3.Username, password+"3")
_, isCached, err = server.authenticate(req) _, isCached, _, err = server.authenticate(req)
assert.NoError(t, err) assert.NoError(t, err)
assert.False(t, isCached) assert.False(t, isCached)
@ -765,7 +793,7 @@ func TestUsersCacheSizeAndExpiration(t *testing.T) {
req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user4.Username), nil) req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user4.Username), nil)
assert.NoError(t, err) assert.NoError(t, err)
req.SetBasicAuth(user4.Username, password+"4") req.SetBasicAuth(user4.Username, password+"4")
_, isCached, err = server.authenticate(req) _, isCached, _, err = server.authenticate(req)
assert.NoError(t, err) assert.NoError(t, err)
assert.False(t, isCached) assert.False(t, isCached)
// user1, the first cached, should be removed now // user1, the first cached, should be removed now
@ -782,7 +810,7 @@ func TestUsersCacheSizeAndExpiration(t *testing.T) {
req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user1.Username), nil) req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user1.Username), nil)
assert.NoError(t, err) assert.NoError(t, err)
req.SetBasicAuth(user1.Username, password+"1") req.SetBasicAuth(user1.Username, password+"1")
_, isCached, err = server.authenticate(req) _, isCached, _, err = server.authenticate(req)
assert.NoError(t, err) assert.NoError(t, err)
assert.False(t, isCached) assert.False(t, isCached)
_, ok = dataprovider.GetCachedWebDAVUser(user2.Username) _, ok = dataprovider.GetCachedWebDAVUser(user2.Username)
@ -798,7 +826,7 @@ func TestUsersCacheSizeAndExpiration(t *testing.T) {
req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user2.Username), nil) req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user2.Username), nil)
assert.NoError(t, err) assert.NoError(t, err)
req.SetBasicAuth(user2.Username, password+"2") req.SetBasicAuth(user2.Username, password+"2")
_, isCached, err = server.authenticate(req) _, isCached, _, err = server.authenticate(req)
assert.NoError(t, err) assert.NoError(t, err)
assert.False(t, isCached) assert.False(t, isCached)
_, ok = dataprovider.GetCachedWebDAVUser(user3.Username) _, ok = dataprovider.GetCachedWebDAVUser(user3.Username)
@ -814,7 +842,7 @@ func TestUsersCacheSizeAndExpiration(t *testing.T) {
req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user3.Username), nil) req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user3.Username), nil)
assert.NoError(t, err) assert.NoError(t, err)
req.SetBasicAuth(user3.Username, password+"3") req.SetBasicAuth(user3.Username, password+"3")
_, isCached, err = server.authenticate(req) _, isCached, _, err = server.authenticate(req)
assert.NoError(t, err) assert.NoError(t, err)
assert.False(t, isCached) assert.False(t, isCached)
_, ok = dataprovider.GetCachedWebDAVUser(user4.Username) _, ok = dataprovider.GetCachedWebDAVUser(user4.Username)
@ -835,14 +863,14 @@ func TestUsersCacheSizeAndExpiration(t *testing.T) {
req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user4.Username), nil) req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user4.Username), nil)
assert.NoError(t, err) assert.NoError(t, err)
req.SetBasicAuth(user4.Username, password+"4") req.SetBasicAuth(user4.Username, password+"4")
_, isCached, err = server.authenticate(req) _, isCached, _, err = server.authenticate(req)
assert.NoError(t, err) assert.NoError(t, err)
assert.False(t, isCached) assert.False(t, isCached)
req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user1.Username), nil) req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user1.Username), nil)
assert.NoError(t, err) assert.NoError(t, err)
req.SetBasicAuth(user1.Username, password+"1") req.SetBasicAuth(user1.Username, password+"1")
_, isCached, err = server.authenticate(req) _, isCached, _, err = server.authenticate(req)
assert.NoError(t, err) assert.NoError(t, err)
assert.False(t, isCached) assert.False(t, isCached)
_, ok = dataprovider.GetCachedWebDAVUser(user2.Username) _, ok = dataprovider.GetCachedWebDAVUser(user2.Username)
@ -874,3 +902,20 @@ func TestRecoverer(t *testing.T) {
server.ServeHTTP(rr, nil) server.ServeHTTP(rr, nil)
assert.Equal(t, http.StatusInternalServerError, rr.Code) assert.Equal(t, http.StatusInternalServerError, rr.Code)
} }
func TestMimeCache(t *testing.T) {
cache := mimeCache{
maxSize: 0,
mimeTypes: make(map[string]string),
}
cache.addMimeToCache(".zip", "application/zip")
mtype := cache.getMimeFromCache(".zip")
assert.Equal(t, "", mtype)
cache.maxSize = 1
cache.addMimeToCache(".zip", "application/zip")
mtype = cache.getMimeFromCache(".zip")
assert.Equal(t, "application/zip", mtype)
cache.addMimeToCache(".jpg", "image/jpeg")
mtype = cache.getMimeFromCache(".jpg")
assert.Equal(t, "", mtype)
}

35
webdavd/mimecache.go Normal file
View file

@ -0,0 +1,35 @@
package webdavd
import "sync"
type mimeCache struct {
maxSize int
sync.RWMutex
mimeTypes map[string]string
}
var mimeTypeCache mimeCache
func (c *mimeCache) addMimeToCache(key, value string) {
c.Lock()
defer c.Unlock()
if key == "" || value == "" {
return
}
if len(c.mimeTypes) >= c.maxSize {
return
}
c.mimeTypes[key] = value
}
func (c *mimeCache) getMimeFromCache(key string) string {
c.RLock()
defer c.RUnlock()
if val, ok := c.mimeTypes[key]; ok {
return val
}
return ""
}

View file

@ -97,13 +97,18 @@ func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
http.Error(w, common.ErrConnectionDenied.Error(), http.StatusForbidden) http.Error(w, common.ErrConnectionDenied.Error(), http.StatusForbidden)
return return
} }
user, isCached, err := s.authenticate(r) user, isCached, lockSystem, err := s.authenticate(r)
if err != nil { if err != nil {
w.Header().Set("WWW-Authenticate", "Basic realm=\"SFTPGo WebDAV\"") w.Header().Set("WWW-Authenticate", "Basic realm=\"SFTPGo WebDAV\"")
http.Error(w, err401.Error(), http.StatusUnauthorized) http.Error(w, err401.Error(), http.StatusUnauthorized)
return return
} }
if path.Clean(r.URL.Path) == "/" && (r.Method == "GET" || r.Method == "PROPFIND" || r.Method == "OPTIONS") {
http.Redirect(w, r, path.Join("/", user.Username), http.StatusMovedPermanently)
return
}
connectionID, err := s.validateUser(user, r) connectionID, err := s.validateUser(user, r)
if err != nil { if err != nil {
updateLoginMetrics(user.Username, r.RemoteAddr, err) updateLoginMetrics(user.Username, r.RemoteAddr, err)
@ -152,49 +157,52 @@ func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
handler := webdav.Handler{ handler := webdav.Handler{
Prefix: prefix, Prefix: prefix,
FileSystem: connection, FileSystem: connection,
LockSystem: webdav.NewMemLS(), LockSystem: lockSystem,
Logger: writeLog, Logger: writeLog,
} }
handler.ServeHTTP(w, r.WithContext(ctx)) handler.ServeHTTP(w, r.WithContext(ctx))
} }
func (s *webDavServer) authenticate(r *http.Request) (dataprovider.User, bool, error) { func (s *webDavServer) authenticate(r *http.Request) (dataprovider.User, bool, webdav.LockSystem, error) {
var user dataprovider.User var user dataprovider.User
var err error var err error
username, password, ok := r.BasicAuth() username, password, ok := r.BasicAuth()
if !ok { if !ok {
return user, false, err401 return user, false, nil, err401
} }
if s.config.Cache.Enabled { if s.config.Cache.Users.Enabled {
result, ok := dataprovider.GetCachedWebDAVUser(username) result, ok := dataprovider.GetCachedWebDAVUser(username)
if ok { if ok {
if result.(dataprovider.CachedUser).IsExpired() { cachedUser := result.(*dataprovider.CachedUser)
if cachedUser.IsExpired() {
dataprovider.RemoveCachedWebDAVUser(username) dataprovider.RemoveCachedWebDAVUser(username)
} else { } else {
if len(password) > 0 && result.(dataprovider.CachedUser).Password == password { if len(password) > 0 && cachedUser.Password == password {
return result.(dataprovider.CachedUser).User, true, nil return cachedUser.User, true, cachedUser.LockSystem, nil
} }
updateLoginMetrics(username, r.RemoteAddr, dataprovider.ErrInvalidCredentials) updateLoginMetrics(username, r.RemoteAddr, dataprovider.ErrInvalidCredentials)
return user, false, dataprovider.ErrInvalidCredentials return user, false, nil, dataprovider.ErrInvalidCredentials
} }
} }
} }
user, err = dataprovider.CheckUserAndPass(username, password, utils.GetIPFromRemoteAddress(r.RemoteAddr), common.ProtocolWebDAV) user, err = dataprovider.CheckUserAndPass(username, password, utils.GetIPFromRemoteAddress(r.RemoteAddr), common.ProtocolWebDAV)
if err != nil { if err != nil {
updateLoginMetrics(username, r.RemoteAddr, err) updateLoginMetrics(username, r.RemoteAddr, err)
return user, false, err return user, false, nil, err
} }
if s.config.Cache.Enabled && len(password) > 0 { lockSystem := webdav.NewMemLS()
cachedUser := dataprovider.CachedUser{ if s.config.Cache.Users.Enabled && len(password) > 0 {
cachedUser := &dataprovider.CachedUser{
User: user, User: user,
Password: password, Password: password,
LockSystem: lockSystem,
} }
if s.config.Cache.ExpirationTime > 0 { if s.config.Cache.Users.ExpirationTime > 0 {
cachedUser.Expiration = time.Now().Add(time.Duration(s.config.Cache.ExpirationTime) * time.Minute) cachedUser.Expiration = time.Now().Add(time.Duration(s.config.Cache.Users.ExpirationTime) * time.Minute)
} }
dataprovider.CacheWebDAVUser(cachedUser, s.config.Cache.MaxSize) dataprovider.CacheWebDAVUser(cachedUser, s.config.Cache.Users.MaxSize)
} }
return user, false, err return user, false, lockSystem, err
} }
func (s *webDavServer) validateUser(user dataprovider.User, r *http.Request) (string, error) { func (s *webDavServer) validateUser(user dataprovider.User, r *http.Request) (string, error) {

View file

@ -34,13 +34,25 @@ type Cors struct {
MaxAge int `json:"max_age" mapstructure:"max_age"` MaxAge int `json:"max_age" mapstructure:"max_age"`
} }
// Cache configuration // UsersCacheConfig defines the cache configuration for users
type Cache struct { type UsersCacheConfig struct {
Enabled bool `json:"enabled" mapstructure:"enabled"` Enabled bool `json:"enabled" mapstructure:"enabled"`
ExpirationTime int `json:"expiration_time" mapstructure:"expiration_time"` ExpirationTime int `json:"expiration_time" mapstructure:"expiration_time"`
MaxSize int `json:"max_size" mapstructure:"max_size"` MaxSize int `json:"max_size" mapstructure:"max_size"`
} }
// MimeCacheConfig defines the cache configuration for mime types
type MimeCacheConfig struct {
Enabled bool `json:"enabled" mapstructure:"enabled"`
MaxSize int `json:"max_size" mapstructure:"max_size"`
}
// Cache configuration
type Cache struct {
Users UsersCacheConfig `json:"users" mapstructure:"users"`
MimeTypes MimeCacheConfig `json:"mime_types" mapstructure:"mime_types"`
}
// Configuration defines the configuration for the WevDAV server // Configuration defines the configuration for the WevDAV server
type Configuration struct { type Configuration struct {
// The port used for serving FTP requests // The port used for serving FTP requests
@ -63,6 +75,13 @@ type Configuration struct {
func (c *Configuration) Initialize(configDir string) error { func (c *Configuration) Initialize(configDir string) error {
var err error var err error
logger.Debug(logSender, "", "initializing WevDav server with config %+v", *c) logger.Debug(logSender, "", "initializing WevDav server with config %+v", *c)
mimeTypeCache = mimeCache{
maxSize: c.Cache.MimeTypes.MaxSize,
mimeTypes: make(map[string]string),
}
if !c.Cache.MimeTypes.Enabled {
mimeTypeCache.maxSize = 0
}
server, err = newServer(c, configDir) server, err = newServer(c, configDir)
if err != nil { if err != nil {
return err return err

View file

@ -3,6 +3,7 @@ package webdavd_test
import ( import (
"crypto/rand" "crypto/rand"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
@ -180,18 +181,19 @@ func TestMain(m *testing.M) {
} }
func TestInitialization(t *testing.T) { func TestInitialization(t *testing.T) {
config := webdavd.Configuration{ cfg := webdavd.Configuration{
BindPort: 1234, BindPort: 1234,
CertificateFile: "missing path", CertificateFile: "missing path",
CertificateKeyFile: "bad path", CertificateKeyFile: "bad path",
} }
err := config.Initialize(configDir) err := cfg.Initialize(configDir)
assert.Error(t, err) assert.Error(t, err)
config.BindPort = webDavServerPort cfg.Cache = config.GetWebDAVDConfig().Cache
config.CertificateFile = certPath cfg.BindPort = webDavServerPort
config.CertificateKeyFile = keyPath cfg.CertificateFile = certPath
err = config.Initialize(configDir) cfg.CertificateKeyFile = keyPath
err = cfg.Initialize(configDir)
assert.Error(t, err) assert.Error(t, err)
err = webdavd.ReloadTLSCertificate() err = webdavd.ReloadTLSCertificate()
assert.NoError(t, err) assert.NoError(t, err)
@ -253,9 +255,11 @@ func TestBasicHandling(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
err = uploadFile(testFilePath, path.Join(testDir, testFileName+".txt"), testFileSize, client) err = uploadFile(testFilePath, path.Join(testDir, testFileName+".txt"), testFileSize, client)
assert.NoError(t, err) assert.NoError(t, err)
err = uploadFile(testFilePath, path.Join(testDir, testFileName), testFileSize, client)
assert.NoError(t, err)
files, err := client.ReadDir(testDir) files, err := client.ReadDir(testDir)
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, files, 4) assert.Len(t, files, 5)
err = client.Copy(testDir, testDir+"_copy", false) err = client.Copy(testDir, testDir+"_copy", false)
assert.NoError(t, err) assert.NoError(t, err)
err = client.RemoveAll(testDir) err = client.RemoveAll(testDir)
@ -303,6 +307,50 @@ func TestLoginInvalidURL(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
} }
func TestRootRedirect(t *testing.T) {
errRedirect := errors.New("redirect error")
u := getTestUser()
user, _, err := httpd.AddUser(u, http.StatusOK)
assert.NoError(t, err)
client := getWebDavClient(user)
assert.NoError(t, checkBasicFunc(client))
rootPath := fmt.Sprintf("http://%v/", webDavServerAddr)
httpClient := httpclient.GetHTTPClient()
httpClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
return errRedirect
}
req, err := http.NewRequest(http.MethodOptions, rootPath, nil)
assert.NoError(t, err)
req.SetBasicAuth(u.Username, u.Password)
resp, err := httpClient.Do(req)
if assert.Error(t, err) {
assert.Contains(t, err.Error(), errRedirect.Error())
}
err = resp.Body.Close()
assert.NoError(t, err)
req, err = http.NewRequest(http.MethodGet, rootPath, nil)
assert.NoError(t, err)
req.SetBasicAuth(u.Username, u.Password)
resp, err = httpClient.Do(req)
if assert.Error(t, err) {
assert.Contains(t, err.Error(), errRedirect.Error())
}
err = resp.Body.Close()
assert.NoError(t, err)
req, err = http.NewRequest("PROPFIND", rootPath, nil)
assert.NoError(t, err)
req.SetBasicAuth(u.Username, u.Password)
resp, err = httpClient.Do(req)
if assert.Error(t, err) {
assert.Contains(t, err.Error(), errRedirect.Error())
}
err = resp.Body.Close()
assert.NoError(t, err)
_, err = httpd.RemoveUser(user, http.StatusOK)
assert.NoError(t, err)
}
func TestLoginExternalAuth(t *testing.T) { func TestLoginExternalAuth(t *testing.T) {
if runtime.GOOS == osWindows { if runtime.GOOS == osWindows {
t.Skip("this test is not available on Windows") t.Skip("this test is not available on Windows")
@ -640,6 +688,9 @@ func TestQuotaLimits(t *testing.T) {
assert.Error(t, err) assert.Error(t, err)
err = client.Rename(testFileName+".quota", testFileName, false) err = client.Rename(testFileName+".quota", testFileName, false)
assert.NoError(t, err) assert.NoError(t, err)
files, err := client.ReadDir("/")
assert.NoError(t, err)
assert.Len(t, files, 1)
// test quota size // test quota size
user.QuotaSize = testFileSize - 1 user.QuotaSize = testFileSize - 1
user.QuotaFiles = 0 user.QuotaFiles = 0
@ -929,7 +980,7 @@ func TestGETAsPROPFIND(t *testing.T) {
} }
} }
client := getWebDavClient(user) client := getWebDavClient(user)
err = client.MkdirAll(path.Join(subDir1, "sub"), os.ModePerm) err = client.MkdirAll(path.Join(subDir1, "sub", "sub1"), os.ModePerm)
assert.NoError(t, err) assert.NoError(t, err)
subPath := fmt.Sprintf("http://%v/%v", webDavServerAddr, path.Join(user.Username, subDir1)) subPath := fmt.Sprintf("http://%v/%v", webDavServerAddr, path.Join(user.Username, subDir1))
req, err = http.NewRequest(http.MethodGet, subPath, nil) req, err = http.NewRequest(http.MethodGet, subPath, nil)
@ -937,10 +988,36 @@ func TestGETAsPROPFIND(t *testing.T) {
req.SetBasicAuth(u.Username, u.Password) req.SetBasicAuth(u.Username, u.Password)
resp, err := httpClient.Do(req) resp, err := httpClient.Do(req)
if assert.NoError(t, err) { if assert.NoError(t, err) {
assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) // before the performance patch we have a 500 here, now we have 207 but an empty list
//assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
assert.Equal(t, http.StatusMultiStatus, resp.StatusCode)
resp.Body.Close() resp.Body.Close()
} }
} }
// we cannot stat the sub at all
subPath1 := fmt.Sprintf("http://%v/%v", webDavServerAddr, path.Join(user.Username, subDir1, "sub"))
req, err = http.NewRequest(http.MethodGet, subPath1, nil)
if assert.NoError(t, err) {
req.SetBasicAuth(u.Username, u.Password)
resp, err := httpClient.Do(req)
if assert.NoError(t, err) {
// here the stat will fail, so the request will not be changed in propfind
assert.Equal(t, http.StatusForbidden, resp.StatusCode)
resp.Body.Close()
}
}
// we have no permission, we get an empty list
files, err := client.ReadDir(subDir1)
assert.NoError(t, err)
assert.Len(t, files, 0)
// if we grant the permissions the files are listed
user.Permissions[subDir1] = []string{dataprovider.PermDownload, dataprovider.PermListItems}
user, _, err = httpd.UpdateUser(user, http.StatusOK, "")
assert.NoError(t, err)
files, err = client.ReadDir(subDir1)
assert.NoError(t, err)
assert.Len(t, files, 1)
_, err = httpd.RemoveUser(user, http.StatusOK) _, err = httpd.RemoveUser(user, http.StatusOK)
assert.NoError(t, err) assert.NoError(t, err)