|
@@ -16,6 +16,7 @@ import (
|
|
"github.com/eikenb/pipeat"
|
|
"github.com/eikenb/pipeat"
|
|
"github.com/pkg/sftp"
|
|
"github.com/pkg/sftp"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/assert"
|
|
|
|
+ "github.com/stretchr/testify/require"
|
|
"golang.org/x/crypto/ssh"
|
|
"golang.org/x/crypto/ssh"
|
|
|
|
|
|
"github.com/drakkan/sftpgo/common"
|
|
"github.com/drakkan/sftpgo/common"
|
|
@@ -159,7 +160,7 @@ func TestUploadResumeInvalidOffset(t *testing.T) {
|
|
fs := vfs.NewOsFs("", os.TempDir(), nil)
|
|
fs := vfs.NewOsFs("", os.TempDir(), nil)
|
|
conn := common.NewBaseConnection("", common.ProtocolSFTP, user, fs)
|
|
conn := common.NewBaseConnection("", common.ProtocolSFTP, user, fs)
|
|
baseTransfer := common.NewBaseTransfer(file, conn, nil, file.Name(), testfile, common.TransferUpload, 10, 0, 0, false, fs)
|
|
baseTransfer := common.NewBaseTransfer(file, conn, nil, file.Name(), testfile, common.TransferUpload, 10, 0, 0, false, fs)
|
|
- transfer := newTransfer(baseTransfer, nil, nil)
|
|
|
|
|
|
+ transfer := newTransfer(baseTransfer, nil, nil, nil)
|
|
_, err = transfer.WriteAt([]byte("test"), 0)
|
|
_, err = transfer.WriteAt([]byte("test"), 0)
|
|
assert.Error(t, err, "upload with invalid offset must fail")
|
|
assert.Error(t, err, "upload with invalid offset must fail")
|
|
if assert.Error(t, transfer.ErrTransfer) {
|
|
if assert.Error(t, transfer.ErrTransfer) {
|
|
@@ -187,7 +188,7 @@ func TestReadWriteErrors(t *testing.T) {
|
|
fs := vfs.NewOsFs("", os.TempDir(), nil)
|
|
fs := vfs.NewOsFs("", os.TempDir(), nil)
|
|
conn := common.NewBaseConnection("", common.ProtocolSFTP, user, fs)
|
|
conn := common.NewBaseConnection("", common.ProtocolSFTP, user, fs)
|
|
baseTransfer := common.NewBaseTransfer(file, conn, nil, file.Name(), testfile, common.TransferDownload, 0, 0, 0, false, fs)
|
|
baseTransfer := common.NewBaseTransfer(file, conn, nil, file.Name(), testfile, common.TransferDownload, 0, 0, 0, false, fs)
|
|
- transfer := newTransfer(baseTransfer, nil, nil)
|
|
|
|
|
|
+ transfer := newTransfer(baseTransfer, nil, nil, nil)
|
|
err = file.Close()
|
|
err = file.Close()
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, err)
|
|
_, err = transfer.WriteAt([]byte("test"), 0)
|
|
_, err = transfer.WriteAt([]byte("test"), 0)
|
|
@@ -201,8 +202,8 @@ func TestReadWriteErrors(t *testing.T) {
|
|
r, _, err := pipeat.Pipe()
|
|
r, _, err := pipeat.Pipe()
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, err)
|
|
baseTransfer = common.NewBaseTransfer(nil, conn, nil, file.Name(), testfile, common.TransferDownload, 0, 0, 0, false, fs)
|
|
baseTransfer = common.NewBaseTransfer(nil, conn, nil, file.Name(), testfile, common.TransferDownload, 0, 0, 0, false, fs)
|
|
- transfer = newTransfer(baseTransfer, nil, r)
|
|
|
|
- err = transfer.closeIO()
|
|
|
|
|
|
+ transfer = newTransfer(baseTransfer, nil, r, nil)
|
|
|
|
+ err = transfer.Close()
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, err)
|
|
_, err = transfer.ReadAt(buf, 0)
|
|
_, err = transfer.ReadAt(buf, 0)
|
|
assert.Error(t, err, "reading from a closed pipe must fail")
|
|
assert.Error(t, err, "reading from a closed pipe must fail")
|
|
@@ -211,7 +212,7 @@ func TestReadWriteErrors(t *testing.T) {
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, err)
|
|
pipeWriter := vfs.NewPipeWriter(w)
|
|
pipeWriter := vfs.NewPipeWriter(w)
|
|
baseTransfer = common.NewBaseTransfer(nil, conn, nil, file.Name(), testfile, common.TransferDownload, 0, 0, 0, false, fs)
|
|
baseTransfer = common.NewBaseTransfer(nil, conn, nil, file.Name(), testfile, common.TransferDownload, 0, 0, 0, false, fs)
|
|
- transfer = newTransfer(baseTransfer, pipeWriter, nil)
|
|
|
|
|
|
+ transfer = newTransfer(baseTransfer, pipeWriter, nil, nil)
|
|
|
|
|
|
err = r.Close()
|
|
err = r.Close()
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, err)
|
|
@@ -224,9 +225,12 @@ func TestReadWriteErrors(t *testing.T) {
|
|
assert.EqualError(t, err, errFake.Error())
|
|
assert.EqualError(t, err, errFake.Error())
|
|
_, err = transfer.WriteAt([]byte("test"), 0)
|
|
_, err = transfer.WriteAt([]byte("test"), 0)
|
|
assert.Error(t, err, "writing to closed pipe must fail")
|
|
assert.Error(t, err, "writing to closed pipe must fail")
|
|
|
|
+ err = transfer.BaseTransfer.Close()
|
|
|
|
+ assert.EqualError(t, err, errFake.Error())
|
|
|
|
|
|
err = os.Remove(testfile)
|
|
err = os.Remove(testfile)
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, err)
|
|
|
|
+ assert.Len(t, conn.GetTransfers(), 0)
|
|
}
|
|
}
|
|
|
|
|
|
func TestUnsupportedListOP(t *testing.T) {
|
|
func TestUnsupportedListOP(t *testing.T) {
|
|
@@ -254,7 +258,7 @@ func TestTransferCancelFn(t *testing.T) {
|
|
fs := vfs.NewOsFs("", os.TempDir(), nil)
|
|
fs := vfs.NewOsFs("", os.TempDir(), nil)
|
|
conn := common.NewBaseConnection("", common.ProtocolSFTP, user, fs)
|
|
conn := common.NewBaseConnection("", common.ProtocolSFTP, user, fs)
|
|
baseTransfer := common.NewBaseTransfer(file, conn, cancelFn, file.Name(), testfile, common.TransferDownload, 0, 0, 0, false, fs)
|
|
baseTransfer := common.NewBaseTransfer(file, conn, cancelFn, file.Name(), testfile, common.TransferDownload, 0, 0, 0, false, fs)
|
|
- transfer := newTransfer(baseTransfer, nil, nil)
|
|
|
|
|
|
+ transfer := newTransfer(baseTransfer, nil, nil, nil)
|
|
|
|
|
|
errFake := errors.New("fake error, this will trigger cancelFn")
|
|
errFake := errors.New("fake error, this will trigger cancelFn")
|
|
transfer.TransferError(errFake)
|
|
transfer.TransferError(errFake)
|
|
@@ -293,7 +297,7 @@ func TestMockFsErrors(t *testing.T) {
|
|
flags.Write = true
|
|
flags.Write = true
|
|
flags.Trunc = false
|
|
flags.Trunc = false
|
|
flags.Append = true
|
|
flags.Append = true
|
|
- _, err = c.handleSFTPUploadToExistingFile(flags, testfile, testfile, 0, "/testfile")
|
|
|
|
|
|
+ _, err = c.handleSFTPUploadToExistingFile(flags, testfile, testfile, 0, "/testfile", nil)
|
|
assert.EqualError(t, err, sftp.ErrSSHFxOpUnsupported.Error())
|
|
assert.EqualError(t, err, sftp.ErrSSHFxOpUnsupported.Error())
|
|
|
|
|
|
fs = newMockOsFs(errFake, nil, false, "123", os.TempDir())
|
|
fs = newMockOsFs(errFake, nil, false, "123", os.TempDir())
|
|
@@ -321,18 +325,18 @@ func TestUploadFiles(t *testing.T) {
|
|
var flags sftp.FileOpenFlags
|
|
var flags sftp.FileOpenFlags
|
|
flags.Write = true
|
|
flags.Write = true
|
|
flags.Trunc = true
|
|
flags.Trunc = true
|
|
- _, err := c.handleSFTPUploadToExistingFile(flags, "missing_path", "other_missing_path", 0, "/missing_path")
|
|
|
|
|
|
+ _, err := c.handleSFTPUploadToExistingFile(flags, "missing_path", "other_missing_path", 0, "/missing_path", nil)
|
|
assert.Error(t, err, "upload to existing file must fail if one or both paths are invalid")
|
|
assert.Error(t, err, "upload to existing file must fail if one or both paths are invalid")
|
|
|
|
|
|
common.Config.UploadMode = common.UploadModeStandard
|
|
common.Config.UploadMode = common.UploadModeStandard
|
|
- _, err = c.handleSFTPUploadToExistingFile(flags, "missing_path", "other_missing_path", 0, "/missing_path")
|
|
|
|
|
|
+ _, err = c.handleSFTPUploadToExistingFile(flags, "missing_path", "other_missing_path", 0, "/missing_path", nil)
|
|
assert.Error(t, err, "upload to existing file must fail if one or both paths are invalid")
|
|
assert.Error(t, err, "upload to existing file must fail if one or both paths are invalid")
|
|
|
|
|
|
missingFile := "missing/relative/file.txt"
|
|
missingFile := "missing/relative/file.txt"
|
|
if runtime.GOOS == osWindows {
|
|
if runtime.GOOS == osWindows {
|
|
missingFile = "missing\\relative\\file.txt"
|
|
missingFile = "missing\\relative\\file.txt"
|
|
}
|
|
}
|
|
- _, err = c.handleSFTPUploadToNewFile(".", missingFile, "/missing")
|
|
|
|
|
|
+ _, err = c.handleSFTPUploadToNewFile(".", missingFile, "/missing", nil)
|
|
assert.Error(t, err, "upload new file in missing path must fail")
|
|
assert.Error(t, err, "upload new file in missing path must fail")
|
|
|
|
|
|
c.BaseConnection.Fs = newMockOsFs(nil, nil, false, "123", os.TempDir())
|
|
c.BaseConnection.Fs = newMockOsFs(nil, nil, false, "123", os.TempDir())
|
|
@@ -341,7 +345,7 @@ func TestUploadFiles(t *testing.T) {
|
|
err = f.Close()
|
|
err = f.Close()
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, err)
|
|
|
|
|
|
- tr, err := c.handleSFTPUploadToExistingFile(flags, f.Name(), f.Name(), 123, f.Name())
|
|
|
|
|
|
+ tr, err := c.handleSFTPUploadToExistingFile(flags, f.Name(), f.Name(), 123, f.Name(), nil)
|
|
if assert.NoError(t, err) {
|
|
if assert.NoError(t, err) {
|
|
transfer := tr.(*transfer)
|
|
transfer := tr.(*transfer)
|
|
transfers := c.GetTransfers()
|
|
transfers := c.GetTransfers()
|
|
@@ -990,7 +994,7 @@ func TestSystemCommandErrors(t *testing.T) {
|
|
sshCmd.connection.channel = &mockSSHChannel
|
|
sshCmd.connection.channel = &mockSSHChannel
|
|
baseTransfer := common.NewBaseTransfer(nil, sshCmd.connection.BaseConnection, nil, "", "", common.TransferDownload,
|
|
baseTransfer := common.NewBaseTransfer(nil, sshCmd.connection.BaseConnection, nil, "", "", common.TransferDownload,
|
|
0, 0, 0, false, fs)
|
|
0, 0, 0, false, fs)
|
|
- transfer := newTransfer(baseTransfer, nil, nil)
|
|
|
|
|
|
+ transfer := newTransfer(baseTransfer, nil, nil, nil)
|
|
destBuff := make([]byte, 65535)
|
|
destBuff := make([]byte, 65535)
|
|
dst := bytes.NewBuffer(destBuff)
|
|
dst := bytes.NewBuffer(destBuff)
|
|
_, err = transfer.copyFromReaderToWriter(dst, sshCmd.connection.channel)
|
|
_, err = transfer.copyFromReaderToWriter(dst, sshCmd.connection.channel)
|
|
@@ -1542,7 +1546,7 @@ func TestSCPUploadFiledata(t *testing.T) {
|
|
|
|
|
|
baseTransfer := common.NewBaseTransfer(file, scpCommand.connection.BaseConnection, nil, file.Name(),
|
|
baseTransfer := common.NewBaseTransfer(file, scpCommand.connection.BaseConnection, nil, file.Name(),
|
|
"/"+testfile, common.TransferDownload, 0, 0, 0, true, fs)
|
|
"/"+testfile, common.TransferDownload, 0, 0, 0, true, fs)
|
|
- transfer := newTransfer(baseTransfer, nil, nil)
|
|
|
|
|
|
+ transfer := newTransfer(baseTransfer, nil, nil, nil)
|
|
|
|
|
|
err = scpCommand.getUploadFileData(2, transfer)
|
|
err = scpCommand.getUploadFileData(2, transfer)
|
|
assert.Error(t, err, "upload must fail, we send a fake write error message")
|
|
assert.Error(t, err, "upload must fail, we send a fake write error message")
|
|
@@ -1574,7 +1578,7 @@ func TestSCPUploadFiledata(t *testing.T) {
|
|
file, err = os.Create(testfile)
|
|
file, err = os.Create(testfile)
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, err)
|
|
baseTransfer.File = file
|
|
baseTransfer.File = file
|
|
- transfer = newTransfer(baseTransfer, nil, nil)
|
|
|
|
|
|
+ transfer = newTransfer(baseTransfer, nil, nil, nil)
|
|
transfer.Connection.AddTransfer(transfer)
|
|
transfer.Connection.AddTransfer(transfer)
|
|
err = scpCommand.getUploadFileData(2, transfer)
|
|
err = scpCommand.getUploadFileData(2, transfer)
|
|
assert.Error(t, err, "upload must fail, we have not enough data to read")
|
|
assert.Error(t, err, "upload must fail, we have not enough data to read")
|
|
@@ -1626,7 +1630,7 @@ func TestUploadError(t *testing.T) {
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, err)
|
|
baseTransfer := common.NewBaseTransfer(file, connection.BaseConnection, nil, testfile,
|
|
baseTransfer := common.NewBaseTransfer(file, connection.BaseConnection, nil, testfile,
|
|
testfile, common.TransferUpload, 0, 0, 0, true, fs)
|
|
testfile, common.TransferUpload, 0, 0, 0, true, fs)
|
|
- transfer := newTransfer(baseTransfer, nil, nil)
|
|
|
|
|
|
+ transfer := newTransfer(baseTransfer, nil, nil, nil)
|
|
|
|
|
|
errFake := errors.New("fake error")
|
|
errFake := errors.New("fake error")
|
|
transfer.TransferError(errFake)
|
|
transfer.TransferError(errFake)
|
|
@@ -1645,6 +1649,49 @@ func TestUploadError(t *testing.T) {
|
|
common.Config.UploadMode = oldUploadMode
|
|
common.Config.UploadMode = oldUploadMode
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+func TestTransferFailingReader(t *testing.T) {
|
|
|
|
+ user := dataprovider.User{
|
|
|
|
+ Username: "testuser",
|
|
|
|
+ }
|
|
|
|
+ user.Permissions = make(map[string][]string)
|
|
|
|
+ user.Permissions["/"] = []string{dataprovider.PermAny}
|
|
|
|
+
|
|
|
|
+ fs := newMockOsFs(nil, nil, true, "", os.TempDir())
|
|
|
|
+ connection := &Connection{
|
|
|
|
+ BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs),
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ request := sftp.NewRequest("Open", "afile.txt")
|
|
|
|
+ request.Flags = 27 // read,write,create,truncate
|
|
|
|
+
|
|
|
|
+ transfer, err := connection.handleFilewrite(request)
|
|
|
|
+ require.NoError(t, err)
|
|
|
|
+ buf := make([]byte, 32)
|
|
|
|
+ _, err = transfer.ReadAt(buf, 0)
|
|
|
|
+ assert.EqualError(t, err, sftp.ErrSSHFxOpUnsupported.Error())
|
|
|
|
+ if c, ok := transfer.(io.Closer); ok {
|
|
|
|
+ err = c.Close()
|
|
|
|
+ assert.EqualError(t, err, sftp.ErrSSHFxFailure.Error())
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ fsPath := filepath.Join(os.TempDir(), "afile.txt")
|
|
|
|
+
|
|
|
|
+ r, _, err := pipeat.Pipe()
|
|
|
|
+ assert.NoError(t, err)
|
|
|
|
+ baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, fsPath, filepath.Base(fsPath), common.TransferUpload, 0, 0, 0, false, fs)
|
|
|
|
+ errRead := errors.New("read is not allowed")
|
|
|
|
+ tr := newTransfer(baseTransfer, nil, r, errRead)
|
|
|
|
+ _, err = tr.ReadAt(buf, 0)
|
|
|
|
+ assert.EqualError(t, err, errRead.Error())
|
|
|
|
+
|
|
|
|
+ err = tr.Close()
|
|
|
|
+ assert.EqualError(t, err, sftp.ErrSSHFxFailure.Error())
|
|
|
|
+
|
|
|
|
+ err = os.Remove(fsPath)
|
|
|
|
+ assert.NoError(t, err)
|
|
|
|
+ assert.Len(t, connection.GetTransfers(), 0)
|
|
|
|
+}
|
|
|
|
+
|
|
func TestConnectionStatusStruct(t *testing.T) {
|
|
func TestConnectionStatusStruct(t *testing.T) {
|
|
var transfers []common.ConnectionTransfer
|
|
var transfers []common.ConnectionTransfer
|
|
transferUL := common.ConnectionTransfer{
|
|
transferUL := common.ConnectionTransfer{
|