|
@@ -18,6 +18,7 @@ import (
|
|
|
|
|
|
"github.com/eikenb/pipeat"
|
|
"github.com/eikenb/pipeat"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/assert"
|
|
|
|
+ "golang.org/x/net/webdav"
|
|
|
|
|
|
"github.com/drakkan/sftpgo/common"
|
|
"github.com/drakkan/sftpgo/common"
|
|
"github.com/drakkan/sftpgo/dataprovider"
|
|
"github.com/drakkan/sftpgo/dataprovider"
|
|
@@ -89,9 +90,9 @@ func (fs MockOsFs) Walk(root string, walkFn filepath.WalkFunc) error {
|
|
return fs.err
|
|
return fs.err
|
|
}
|
|
}
|
|
|
|
|
|
-// GetMimeType implements vfs.MimeTyper
|
|
|
|
|
|
+// GetMimeType returns the content type
|
|
func (fs MockOsFs) GetMimeType(name string) (string, error) {
|
|
func (fs MockOsFs) GetMimeType(name string) (string, error) {
|
|
- return "application/octet-stream", nil
|
|
|
|
|
|
+ return "application/custom-mime", nil
|
|
}
|
|
}
|
|
|
|
|
|
func newMockOsFs(err error, atomicUpload bool, connectionID, rootDir string) vfs.Fs {
|
|
func newMockOsFs(err error, atomicUpload bool, connectionID, rootDir string) vfs.Fs {
|
|
@@ -319,13 +320,11 @@ func TestFileAccessErrors(t *testing.T) {
|
|
if assert.Error(t, err) {
|
|
if assert.Error(t, err) {
|
|
assert.EqualError(t, err, os.ErrNotExist.Error())
|
|
assert.EqualError(t, err, os.ErrNotExist.Error())
|
|
}
|
|
}
|
|
- info := vfs.NewFileInfo(missingPath, true, 0, time.Now(), false)
|
|
|
|
- _, err = connection.getFile(fsMissingPath, missingPath, info)
|
|
|
|
|
|
+ _, err = connection.getFile(fsMissingPath, missingPath)
|
|
if assert.Error(t, err) {
|
|
if assert.Error(t, err) {
|
|
assert.EqualError(t, err, os.ErrNotExist.Error())
|
|
assert.EqualError(t, err, os.ErrNotExist.Error())
|
|
}
|
|
}
|
|
- info = vfs.NewFileInfo(missingPath, false, 123, time.Now(), false)
|
|
|
|
- _, err = connection.getFile(fsMissingPath, missingPath, info)
|
|
|
|
|
|
+ _, err = connection.getFile(fsMissingPath, missingPath)
|
|
if assert.Error(t, err) {
|
|
if assert.Error(t, err) {
|
|
assert.EqualError(t, err, os.ErrNotExist.Error())
|
|
assert.EqualError(t, err, os.ErrNotExist.Error())
|
|
}
|
|
}
|
|
@@ -434,20 +433,34 @@ func TestContentType(t *testing.T) {
|
|
fs = newMockOsFs(nil, false, fs.ConnectionID(), user.GetHomeDir())
|
|
fs = newMockOsFs(nil, false, fs.ConnectionID(), user.GetHomeDir())
|
|
err := ioutil.WriteFile(testFilePath, []byte(""), os.ModePerm)
|
|
err := ioutil.WriteFile(testFilePath, []byte(""), os.ModePerm)
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, err)
|
|
- fi, err := os.Stat(testFilePath)
|
|
|
|
- assert.NoError(t, err)
|
|
|
|
- davFile := newWebDavFile(baseTransfer, nil, nil, fi)
|
|
|
|
|
|
+ davFile := newWebDavFile(baseTransfer, nil, nil)
|
|
davFile.Fs = fs
|
|
davFile.Fs = fs
|
|
- fi, err = davFile.Stat()
|
|
|
|
|
|
+ fi, err := davFile.Stat()
|
|
if assert.NoError(t, err) {
|
|
if assert.NoError(t, err) {
|
|
- ctype, err := fi.(webDavFileInfo).ContentType(ctx)
|
|
|
|
|
|
+ ctype, err := fi.(*webDavFileInfo).ContentType(ctx)
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, err)
|
|
- assert.Equal(t, "application/octet-stream", ctype)
|
|
|
|
|
|
+ assert.Equal(t, "application/custom-mime", ctype)
|
|
}
|
|
}
|
|
_, err = davFile.Readdir(-1)
|
|
_, err = davFile.Readdir(-1)
|
|
assert.Error(t, err)
|
|
assert.Error(t, err)
|
|
err = davFile.Close()
|
|
err = davFile.Close()
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, err)
|
|
|
|
+
|
|
|
|
+ davFile = newWebDavFile(baseTransfer, nil, nil)
|
|
|
|
+ davFile.Fs = vfs.NewOsFs("id", user.HomeDir, nil)
|
|
|
|
+ fi, err = davFile.Stat()
|
|
|
|
+ if assert.NoError(t, err) {
|
|
|
|
+ ctype, err := fi.(*webDavFileInfo).ContentType(ctx)
|
|
|
|
+ assert.NoError(t, err)
|
|
|
|
+ assert.Equal(t, "text/plain; charset=utf-8", ctype)
|
|
|
|
+ }
|
|
|
|
+ err = davFile.Close()
|
|
|
|
+ assert.NoError(t, err)
|
|
|
|
+
|
|
|
|
+ fi.(*webDavFileInfo).fsPath = "missing"
|
|
|
|
+ _, err = fi.(*webDavFileInfo).ContentType(ctx)
|
|
|
|
+ assert.EqualError(t, err, webdav.ErrNotImplemented.Error())
|
|
|
|
+
|
|
err = os.Remove(testFilePath)
|
|
err = os.Remove(testFilePath)
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, err)
|
|
}
|
|
}
|
|
@@ -465,17 +478,16 @@ func TestTransferReadWriteErrors(t *testing.T) {
|
|
testFilePath := filepath.Join(user.HomeDir, testFile)
|
|
testFilePath := filepath.Join(user.HomeDir, testFile)
|
|
baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFile,
|
|
baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFile,
|
|
common.TransferUpload, 0, 0, 0, false, fs)
|
|
common.TransferUpload, 0, 0, 0, false, fs)
|
|
- davFile := newWebDavFile(baseTransfer, nil, nil, nil)
|
|
|
|
- assert.False(t, davFile.isDir())
|
|
|
|
|
|
+ davFile := newWebDavFile(baseTransfer, nil, nil)
|
|
p := make([]byte, 1)
|
|
p := make([]byte, 1)
|
|
_, err := davFile.Read(p)
|
|
_, err := davFile.Read(p)
|
|
assert.EqualError(t, err, common.ErrOpUnsupported.Error())
|
|
assert.EqualError(t, err, common.ErrOpUnsupported.Error())
|
|
|
|
|
|
r, w, err := pipeat.Pipe()
|
|
r, w, err := pipeat.Pipe()
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, err)
|
|
- davFile = newWebDavFile(baseTransfer, nil, r, nil)
|
|
|
|
|
|
+ davFile = newWebDavFile(baseTransfer, nil, r)
|
|
davFile.Connection.RemoveTransfer(davFile.BaseTransfer)
|
|
davFile.Connection.RemoveTransfer(davFile.BaseTransfer)
|
|
- davFile = newWebDavFile(baseTransfer, vfs.NewPipeWriter(w), nil, nil)
|
|
|
|
|
|
+ davFile = newWebDavFile(baseTransfer, vfs.NewPipeWriter(w), nil)
|
|
davFile.Connection.RemoveTransfer(davFile.BaseTransfer)
|
|
davFile.Connection.RemoveTransfer(davFile.BaseTransfer)
|
|
err = r.Close()
|
|
err = r.Close()
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, err)
|
|
@@ -484,7 +496,7 @@ func TestTransferReadWriteErrors(t *testing.T) {
|
|
|
|
|
|
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFile,
|
|
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFile,
|
|
common.TransferDownload, 0, 0, 0, false, fs)
|
|
common.TransferDownload, 0, 0, 0, false, fs)
|
|
- davFile = newWebDavFile(baseTransfer, nil, nil, nil)
|
|
|
|
|
|
+ davFile = newWebDavFile(baseTransfer, nil, nil)
|
|
_, err = davFile.Read(p)
|
|
_, err = davFile.Read(p)
|
|
assert.True(t, os.IsNotExist(err))
|
|
assert.True(t, os.IsNotExist(err))
|
|
_, err = davFile.Stat()
|
|
_, err = davFile.Stat()
|
|
@@ -499,7 +511,7 @@ func TestTransferReadWriteErrors(t *testing.T) {
|
|
err = f.Close()
|
|
err = f.Close()
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, err)
|
|
}
|
|
}
|
|
- davFile = newWebDavFile(baseTransfer, nil, nil, nil)
|
|
|
|
|
|
+ davFile = newWebDavFile(baseTransfer, nil, nil)
|
|
davFile.reader = f
|
|
davFile.reader = f
|
|
err = davFile.Close()
|
|
err = davFile.Close()
|
|
assert.EqualError(t, err, common.ErrGenericFailure.Error())
|
|
assert.EqualError(t, err, common.ErrGenericFailure.Error())
|
|
@@ -514,7 +526,7 @@ func TestTransferReadWriteErrors(t *testing.T) {
|
|
|
|
|
|
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFile,
|
|
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFile,
|
|
common.TransferDownload, 0, 0, 0, false, fs)
|
|
common.TransferDownload, 0, 0, 0, false, fs)
|
|
- davFile = newWebDavFile(baseTransfer, nil, nil, nil)
|
|
|
|
|
|
+ davFile = newWebDavFile(baseTransfer, nil, nil)
|
|
davFile.writer = f
|
|
davFile.writer = f
|
|
err = davFile.Close()
|
|
err = davFile.Close()
|
|
assert.EqualError(t, err, common.ErrGenericFailure.Error())
|
|
assert.EqualError(t, err, common.ErrGenericFailure.Error())
|
|
@@ -534,9 +546,10 @@ func TestTransferSeek(t *testing.T) {
|
|
BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, user, fs),
|
|
BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, user, fs),
|
|
}
|
|
}
|
|
testFilePath := filepath.Join(user.HomeDir, testFile)
|
|
testFilePath := filepath.Join(user.HomeDir, testFile)
|
|
|
|
+ testFileContents := []byte("content")
|
|
baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFile,
|
|
baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFile,
|
|
common.TransferUpload, 0, 0, 0, false, fs)
|
|
common.TransferUpload, 0, 0, 0, false, fs)
|
|
- davFile := newWebDavFile(baseTransfer, nil, nil, nil)
|
|
|
|
|
|
+ davFile := newWebDavFile(baseTransfer, nil, nil)
|
|
_, err := davFile.Seek(0, io.SeekStart)
|
|
_, err := davFile.Seek(0, io.SeekStart)
|
|
assert.EqualError(t, err, common.ErrOpUnsupported.Error())
|
|
assert.EqualError(t, err, common.ErrOpUnsupported.Error())
|
|
err = davFile.Close()
|
|
err = davFile.Close()
|
|
@@ -544,12 +557,12 @@ func TestTransferSeek(t *testing.T) {
|
|
|
|
|
|
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFile,
|
|
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFile,
|
|
common.TransferDownload, 0, 0, 0, false, fs)
|
|
common.TransferDownload, 0, 0, 0, false, fs)
|
|
- davFile = newWebDavFile(baseTransfer, nil, nil, nil)
|
|
|
|
|
|
+ davFile = newWebDavFile(baseTransfer, nil, nil)
|
|
_, err = davFile.Seek(0, io.SeekCurrent)
|
|
_, err = davFile.Seek(0, io.SeekCurrent)
|
|
assert.True(t, os.IsNotExist(err))
|
|
assert.True(t, os.IsNotExist(err))
|
|
davFile.Connection.RemoveTransfer(davFile.BaseTransfer)
|
|
davFile.Connection.RemoveTransfer(davFile.BaseTransfer)
|
|
|
|
|
|
- err = ioutil.WriteFile(testFilePath, []byte("content"), os.ModePerm)
|
|
|
|
|
|
+ err = ioutil.WriteFile(testFilePath, testFileContents, os.ModePerm)
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, err)
|
|
f, err := os.Open(testFilePath)
|
|
f, err := os.Open(testFilePath)
|
|
if assert.NoError(t, err) {
|
|
if assert.NoError(t, err) {
|
|
@@ -558,44 +571,55 @@ func TestTransferSeek(t *testing.T) {
|
|
}
|
|
}
|
|
baseTransfer = common.NewBaseTransfer(f, connection.BaseConnection, nil, testFilePath, testFile,
|
|
baseTransfer = common.NewBaseTransfer(f, connection.BaseConnection, nil, testFilePath, testFile,
|
|
common.TransferDownload, 0, 0, 0, false, fs)
|
|
common.TransferDownload, 0, 0, 0, false, fs)
|
|
- davFile = newWebDavFile(baseTransfer, nil, nil, nil)
|
|
|
|
|
|
+ davFile = newWebDavFile(baseTransfer, nil, nil)
|
|
_, err = davFile.Seek(0, io.SeekStart)
|
|
_, err = davFile.Seek(0, io.SeekStart)
|
|
assert.Error(t, err)
|
|
assert.Error(t, err)
|
|
davFile.Connection.RemoveTransfer(davFile.BaseTransfer)
|
|
davFile.Connection.RemoveTransfer(davFile.BaseTransfer)
|
|
|
|
|
|
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFile,
|
|
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFile,
|
|
common.TransferDownload, 0, 0, 0, false, fs)
|
|
common.TransferDownload, 0, 0, 0, false, fs)
|
|
- davFile = newWebDavFile(baseTransfer, nil, nil, nil)
|
|
|
|
- davFile.reader = f
|
|
|
|
|
|
+ davFile = newWebDavFile(baseTransfer, nil, nil)
|
|
res, err := davFile.Seek(0, io.SeekStart)
|
|
res, err := davFile.Seek(0, io.SeekStart)
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, int64(0), res)
|
|
assert.Equal(t, int64(0), res)
|
|
davFile.Connection.RemoveTransfer(davFile.BaseTransfer)
|
|
davFile.Connection.RemoveTransfer(davFile.BaseTransfer)
|
|
|
|
|
|
- info, err := os.Stat(testFilePath)
|
|
|
|
- assert.NoError(t, err)
|
|
|
|
- davFile = newWebDavFile(baseTransfer, nil, nil, info)
|
|
|
|
- davFile.reader = f
|
|
|
|
|
|
+ davFile = newWebDavFile(baseTransfer, nil, nil)
|
|
res, err = davFile.Seek(0, io.SeekEnd)
|
|
res, err = davFile.Seek(0, io.SeekEnd)
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, err)
|
|
- assert.Equal(t, int64(7), res)
|
|
|
|
|
|
+ assert.Equal(t, int64(len(testFileContents)), res)
|
|
|
|
+ err = davFile.updateStatInfo()
|
|
|
|
+ assert.Nil(t, err)
|
|
|
|
|
|
- davFile = newWebDavFile(baseTransfer, nil, nil, info)
|
|
|
|
|
|
+ baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath+"1", testFile,
|
|
|
|
+ common.TransferDownload, 0, 0, 0, false, fs)
|
|
|
|
+ davFile = newWebDavFile(baseTransfer, nil, nil)
|
|
|
|
+ _, err = davFile.Seek(0, io.SeekEnd)
|
|
|
|
+ assert.True(t, os.IsNotExist(err))
|
|
|
|
+ davFile.Connection.RemoveTransfer(davFile.BaseTransfer)
|
|
|
|
+
|
|
|
|
+ baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFile,
|
|
|
|
+ common.TransferDownload, 0, 0, 0, false, fs)
|
|
|
|
+ davFile = newWebDavFile(baseTransfer, nil, nil)
|
|
davFile.reader = f
|
|
davFile.reader = f
|
|
davFile.Fs = newMockOsFs(nil, true, fs.ConnectionID(), user.GetHomeDir())
|
|
davFile.Fs = newMockOsFs(nil, true, fs.ConnectionID(), user.GetHomeDir())
|
|
res, err = davFile.Seek(2, io.SeekStart)
|
|
res, err = davFile.Seek(2, io.SeekStart)
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, int64(2), res)
|
|
assert.Equal(t, int64(2), res)
|
|
|
|
|
|
- davFile = newWebDavFile(baseTransfer, nil, nil, info)
|
|
|
|
|
|
+ davFile = newWebDavFile(baseTransfer, nil, nil)
|
|
davFile.Fs = newMockOsFs(nil, true, fs.ConnectionID(), user.GetHomeDir())
|
|
davFile.Fs = newMockOsFs(nil, true, fs.ConnectionID(), user.GetHomeDir())
|
|
res, err = davFile.Seek(2, io.SeekEnd)
|
|
res, err = davFile.Seek(2, io.SeekEnd)
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, int64(5), res)
|
|
assert.Equal(t, int64(5), res)
|
|
|
|
|
|
- davFile = newWebDavFile(baseTransfer, nil, nil, nil)
|
|
|
|
|
|
+ baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath+"1", testFile,
|
|
|
|
+ common.TransferDownload, 0, 0, 0, false, fs)
|
|
|
|
+
|
|
|
|
+ davFile = newWebDavFile(baseTransfer, nil, nil)
|
|
|
|
+ davFile.Fs = newMockOsFs(nil, true, fs.ConnectionID(), user.GetHomeDir())
|
|
res, err = davFile.Seek(2, io.SeekEnd)
|
|
res, err = davFile.Seek(2, io.SeekEnd)
|
|
- assert.EqualError(t, err, "unable to get file size, seek from end not possible")
|
|
|
|
|
|
+ assert.True(t, os.IsNotExist(err))
|
|
assert.Equal(t, int64(0), res)
|
|
assert.Equal(t, int64(0), res)
|
|
|
|
|
|
assert.Len(t, common.Connections.GetStats(), 0)
|
|
assert.Len(t, common.Connections.GetStats(), 0)
|
|
@@ -622,9 +646,11 @@ func TestBasicUsersCache(t *testing.T) {
|
|
c := &Configuration{
|
|
c := &Configuration{
|
|
BindPort: 9000,
|
|
BindPort: 9000,
|
|
Cache: Cache{
|
|
Cache: Cache{
|
|
- Enabled: true,
|
|
|
|
- MaxSize: 50,
|
|
|
|
- ExpirationTime: 1,
|
|
|
|
|
|
+ Users: UsersCacheConfig{
|
|
|
|
+ Enabled: true,
|
|
|
|
+ MaxSize: 50,
|
|
|
|
+ ExpirationTime: 1,
|
|
|
|
+ },
|
|
},
|
|
},
|
|
}
|
|
}
|
|
server, err := newServer(c, configDir)
|
|
server, err := newServer(c, configDir)
|
|
@@ -633,48 +659,48 @@ func TestBasicUsersCache(t *testing.T) {
|
|
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user.Username), nil)
|
|
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user.Username), nil)
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, err)
|
|
|
|
|
|
- _, _, err = server.authenticate(req)
|
|
|
|
|
|
+ _, _, _, err = server.authenticate(req) //nolint:dogsled
|
|
assert.Error(t, err)
|
|
assert.Error(t, err)
|
|
|
|
|
|
now := time.Now()
|
|
now := time.Now()
|
|
req.SetBasicAuth(username, password)
|
|
req.SetBasicAuth(username, password)
|
|
- _, isCached, err := server.authenticate(req)
|
|
|
|
|
|
+ _, isCached, _, err := server.authenticate(req)
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, err)
|
|
assert.False(t, isCached)
|
|
assert.False(t, isCached)
|
|
// now the user should be cached
|
|
// now the user should be cached
|
|
- var cachedUser dataprovider.CachedUser
|
|
|
|
|
|
+ var cachedUser *dataprovider.CachedUser
|
|
result, ok := dataprovider.GetCachedWebDAVUser(username)
|
|
result, ok := dataprovider.GetCachedWebDAVUser(username)
|
|
if assert.True(t, ok) {
|
|
if assert.True(t, ok) {
|
|
- cachedUser = result.(dataprovider.CachedUser)
|
|
|
|
|
|
+ cachedUser = result.(*dataprovider.CachedUser)
|
|
assert.False(t, cachedUser.IsExpired())
|
|
assert.False(t, cachedUser.IsExpired())
|
|
- assert.True(t, cachedUser.Expiration.After(now.Add(time.Duration(c.Cache.ExpirationTime)*time.Minute)))
|
|
|
|
|
|
+ assert.True(t, cachedUser.Expiration.After(now.Add(time.Duration(c.Cache.Users.ExpirationTime)*time.Minute)))
|
|
// authenticate must return the cached user now
|
|
// authenticate must return the cached user now
|
|
- authUser, isCached, err := server.authenticate(req)
|
|
|
|
|
|
+ authUser, isCached, _, err := server.authenticate(req)
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, err)
|
|
assert.True(t, isCached)
|
|
assert.True(t, isCached)
|
|
assert.Equal(t, cachedUser.User, authUser)
|
|
assert.Equal(t, cachedUser.User, authUser)
|
|
}
|
|
}
|
|
// a wrong password must fail
|
|
// a wrong password must fail
|
|
req.SetBasicAuth(username, "wrong")
|
|
req.SetBasicAuth(username, "wrong")
|
|
- _, _, err = server.authenticate(req)
|
|
|
|
|
|
+ _, _, _, err = server.authenticate(req) //nolint:dogsled
|
|
assert.EqualError(t, err, dataprovider.ErrInvalidCredentials.Error())
|
|
assert.EqualError(t, err, dataprovider.ErrInvalidCredentials.Error())
|
|
req.SetBasicAuth(username, password)
|
|
req.SetBasicAuth(username, password)
|
|
|
|
|
|
// force cached user expiration
|
|
// force cached user expiration
|
|
cachedUser.Expiration = now
|
|
cachedUser.Expiration = now
|
|
- dataprovider.CacheWebDAVUser(cachedUser, c.Cache.MaxSize)
|
|
|
|
|
|
+ dataprovider.CacheWebDAVUser(cachedUser, c.Cache.Users.MaxSize)
|
|
result, ok = dataprovider.GetCachedWebDAVUser(username)
|
|
result, ok = dataprovider.GetCachedWebDAVUser(username)
|
|
if assert.True(t, ok) {
|
|
if assert.True(t, ok) {
|
|
- cachedUser = result.(dataprovider.CachedUser)
|
|
|
|
|
|
+ cachedUser = result.(*dataprovider.CachedUser)
|
|
assert.True(t, cachedUser.IsExpired())
|
|
assert.True(t, cachedUser.IsExpired())
|
|
}
|
|
}
|
|
// now authenticate should get the user from the data provider and update the cache
|
|
// now authenticate should get the user from the data provider and update the cache
|
|
- _, isCached, err = server.authenticate(req)
|
|
|
|
|
|
+ _, isCached, _, err = server.authenticate(req)
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, err)
|
|
assert.False(t, isCached)
|
|
assert.False(t, isCached)
|
|
result, ok = dataprovider.GetCachedWebDAVUser(username)
|
|
result, ok = dataprovider.GetCachedWebDAVUser(username)
|
|
if assert.True(t, ok) {
|
|
if assert.True(t, ok) {
|
|
- cachedUser = result.(dataprovider.CachedUser)
|
|
|
|
|
|
+ cachedUser = result.(*dataprovider.CachedUser)
|
|
assert.False(t, cachedUser.IsExpired())
|
|
assert.False(t, cachedUser.IsExpired())
|
|
}
|
|
}
|
|
// cache is invalidated after a user modification
|
|
// cache is invalidated after a user modification
|
|
@@ -683,7 +709,7 @@ func TestBasicUsersCache(t *testing.T) {
|
|
_, ok = dataprovider.GetCachedWebDAVUser(username)
|
|
_, ok = dataprovider.GetCachedWebDAVUser(username)
|
|
assert.False(t, ok)
|
|
assert.False(t, ok)
|
|
|
|
|
|
- _, isCached, err = server.authenticate(req)
|
|
|
|
|
|
+ _, isCached, _, err = server.authenticate(req)
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, err)
|
|
assert.False(t, isCached)
|
|
assert.False(t, isCached)
|
|
_, ok = dataprovider.GetCachedWebDAVUser(username)
|
|
_, ok = dataprovider.GetCachedWebDAVUser(username)
|
|
@@ -725,9 +751,11 @@ func TestUsersCacheSizeAndExpiration(t *testing.T) {
|
|
c := &Configuration{
|
|
c := &Configuration{
|
|
BindPort: 9000,
|
|
BindPort: 9000,
|
|
Cache: Cache{
|
|
Cache: Cache{
|
|
- Enabled: true,
|
|
|
|
- MaxSize: 3,
|
|
|
|
- ExpirationTime: 1,
|
|
|
|
|
|
+ Users: UsersCacheConfig{
|
|
|
|
+ Enabled: true,
|
|
|
|
+ MaxSize: 3,
|
|
|
|
+ ExpirationTime: 1,
|
|
|
|
+ },
|
|
},
|
|
},
|
|
}
|
|
}
|
|
server, err := newServer(c, configDir)
|
|
server, err := newServer(c, configDir)
|
|
@@ -736,21 +764,21 @@ func TestUsersCacheSizeAndExpiration(t *testing.T) {
|
|
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user1.Username), nil)
|
|
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user1.Username), nil)
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, err)
|
|
req.SetBasicAuth(user1.Username, password+"1")
|
|
req.SetBasicAuth(user1.Username, password+"1")
|
|
- _, isCached, err := server.authenticate(req)
|
|
|
|
|
|
+ _, isCached, _, err := server.authenticate(req)
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, err)
|
|
assert.False(t, isCached)
|
|
assert.False(t, isCached)
|
|
|
|
|
|
req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user2.Username), nil)
|
|
req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user2.Username), nil)
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, err)
|
|
req.SetBasicAuth(user2.Username, password+"2")
|
|
req.SetBasicAuth(user2.Username, password+"2")
|
|
- _, isCached, err = server.authenticate(req)
|
|
|
|
|
|
+ _, isCached, _, err = server.authenticate(req)
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, err)
|
|
assert.False(t, isCached)
|
|
assert.False(t, isCached)
|
|
|
|
|
|
req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user3.Username), nil)
|
|
req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user3.Username), nil)
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, err)
|
|
req.SetBasicAuth(user3.Username, password+"3")
|
|
req.SetBasicAuth(user3.Username, password+"3")
|
|
- _, isCached, err = server.authenticate(req)
|
|
|
|
|
|
+ _, isCached, _, err = server.authenticate(req)
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, err)
|
|
assert.False(t, isCached)
|
|
assert.False(t, isCached)
|
|
|
|
|
|
@@ -765,7 +793,7 @@ func TestUsersCacheSizeAndExpiration(t *testing.T) {
|
|
req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user4.Username), nil)
|
|
req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user4.Username), nil)
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, err)
|
|
req.SetBasicAuth(user4.Username, password+"4")
|
|
req.SetBasicAuth(user4.Username, password+"4")
|
|
- _, isCached, err = server.authenticate(req)
|
|
|
|
|
|
+ _, isCached, _, err = server.authenticate(req)
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, err)
|
|
assert.False(t, isCached)
|
|
assert.False(t, isCached)
|
|
// user1, the first cached, should be removed now
|
|
// user1, the first cached, should be removed now
|
|
@@ -782,7 +810,7 @@ func TestUsersCacheSizeAndExpiration(t *testing.T) {
|
|
req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user1.Username), nil)
|
|
req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user1.Username), nil)
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, err)
|
|
req.SetBasicAuth(user1.Username, password+"1")
|
|
req.SetBasicAuth(user1.Username, password+"1")
|
|
- _, isCached, err = server.authenticate(req)
|
|
|
|
|
|
+ _, isCached, _, err = server.authenticate(req)
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, err)
|
|
assert.False(t, isCached)
|
|
assert.False(t, isCached)
|
|
_, ok = dataprovider.GetCachedWebDAVUser(user2.Username)
|
|
_, ok = dataprovider.GetCachedWebDAVUser(user2.Username)
|
|
@@ -798,7 +826,7 @@ func TestUsersCacheSizeAndExpiration(t *testing.T) {
|
|
req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user2.Username), nil)
|
|
req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user2.Username), nil)
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, err)
|
|
req.SetBasicAuth(user2.Username, password+"2")
|
|
req.SetBasicAuth(user2.Username, password+"2")
|
|
- _, isCached, err = server.authenticate(req)
|
|
|
|
|
|
+ _, isCached, _, err = server.authenticate(req)
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, err)
|
|
assert.False(t, isCached)
|
|
assert.False(t, isCached)
|
|
_, ok = dataprovider.GetCachedWebDAVUser(user3.Username)
|
|
_, ok = dataprovider.GetCachedWebDAVUser(user3.Username)
|
|
@@ -814,7 +842,7 @@ func TestUsersCacheSizeAndExpiration(t *testing.T) {
|
|
req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user3.Username), nil)
|
|
req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user3.Username), nil)
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, err)
|
|
req.SetBasicAuth(user3.Username, password+"3")
|
|
req.SetBasicAuth(user3.Username, password+"3")
|
|
- _, isCached, err = server.authenticate(req)
|
|
|
|
|
|
+ _, isCached, _, err = server.authenticate(req)
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, err)
|
|
assert.False(t, isCached)
|
|
assert.False(t, isCached)
|
|
_, ok = dataprovider.GetCachedWebDAVUser(user4.Username)
|
|
_, ok = dataprovider.GetCachedWebDAVUser(user4.Username)
|
|
@@ -835,14 +863,14 @@ func TestUsersCacheSizeAndExpiration(t *testing.T) {
|
|
req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user4.Username), nil)
|
|
req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user4.Username), nil)
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, err)
|
|
req.SetBasicAuth(user4.Username, password+"4")
|
|
req.SetBasicAuth(user4.Username, password+"4")
|
|
- _, isCached, err = server.authenticate(req)
|
|
|
|
|
|
+ _, isCached, _, err = server.authenticate(req)
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, err)
|
|
assert.False(t, isCached)
|
|
assert.False(t, isCached)
|
|
|
|
|
|
req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user1.Username), nil)
|
|
req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user1.Username), nil)
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, err)
|
|
req.SetBasicAuth(user1.Username, password+"1")
|
|
req.SetBasicAuth(user1.Username, password+"1")
|
|
- _, isCached, err = server.authenticate(req)
|
|
|
|
|
|
+ _, isCached, _, err = server.authenticate(req)
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, err)
|
|
assert.False(t, isCached)
|
|
assert.False(t, isCached)
|
|
_, ok = dataprovider.GetCachedWebDAVUser(user2.Username)
|
|
_, ok = dataprovider.GetCachedWebDAVUser(user2.Username)
|
|
@@ -874,3 +902,20 @@ func TestRecoverer(t *testing.T) {
|
|
server.ServeHTTP(rr, nil)
|
|
server.ServeHTTP(rr, nil)
|
|
assert.Equal(t, http.StatusInternalServerError, rr.Code)
|
|
assert.Equal(t, http.StatusInternalServerError, rr.Code)
|
|
}
|
|
}
|
|
|
|
+
|
|
|
|
+func TestMimeCache(t *testing.T) {
|
|
|
|
+ cache := mimeCache{
|
|
|
|
+ maxSize: 0,
|
|
|
|
+ mimeTypes: make(map[string]string),
|
|
|
|
+ }
|
|
|
|
+ cache.addMimeToCache(".zip", "application/zip")
|
|
|
|
+ mtype := cache.getMimeFromCache(".zip")
|
|
|
|
+ assert.Equal(t, "", mtype)
|
|
|
|
+ cache.maxSize = 1
|
|
|
|
+ cache.addMimeToCache(".zip", "application/zip")
|
|
|
|
+ mtype = cache.getMimeFromCache(".zip")
|
|
|
|
+ assert.Equal(t, "application/zip", mtype)
|
|
|
|
+ cache.addMimeToCache(".jpg", "image/jpeg")
|
|
|
|
+ mtype = cache.getMimeFromCache(".jpg")
|
|
|
|
+ assert.Equal(t, "", mtype)
|
|
|
|
+}
|