123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674 |
- package sftpd
- import (
- "bytes"
- "fmt"
- "io"
- "io/ioutil"
- "os"
- "runtime"
- "testing"
- "time"
- "github.com/drakkan/sftpgo/dataprovider"
- "github.com/pkg/sftp"
- )
// MockChannel is an in-memory stand-in for the channel handed to scpCommand
// in these tests. Its methods either operate on the configured buffers or
// return the configured errors.
type MockChannel struct {
	// Buffer backs Read and Write; StdErrBuffer is what Stderr returns.
	Buffer, StdErrBuffer *bytes.Buffer
	// When non-nil, ReadError / WriteError is returned by Read / Write
	// instead of touching Buffer.
	ReadError, WriteError error
}
- func (c *MockChannel) Read(data []byte) (int, error) {
- if c.ReadError != nil {
- return 0, c.ReadError
- }
- return c.Buffer.Read(data)
- }
- func (c *MockChannel) Write(data []byte) (int, error) {
- if c.WriteError != nil {
- return 0, c.WriteError
- }
- return c.Buffer.Write(data)
- }
- func (c *MockChannel) Close() error {
- return nil
- }
- func (c *MockChannel) CloseWrite() error {
- return nil
- }
- func (c *MockChannel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
- return true, nil
- }
- func (c *MockChannel) Stderr() io.ReadWriter {
- return c.StdErrBuffer
- }
- func TestWrongActions(t *testing.T) {
- actionsCopy := actions
- badCommand := "/bad/command"
- if runtime.GOOS == "windows" {
- badCommand = "C:\\bad\\command"
- }
- actions = Actions{
- ExecuteOn: []string{operationDownload},
- Command: badCommand,
- HTTPNotificationURL: "",
- }
- err := executeAction(operationDownload, "username", "path", "")
- if err == nil {
- t.Errorf("action with bad command must fail")
- }
- err = executeAction(operationDelete, "username", "path", "")
- if err != nil {
- t.Errorf("action not configured must silently fail")
- }
- actions.Command = ""
- actions.HTTPNotificationURL = "http://foo\x7f.com/"
- err = executeAction(operationDownload, "username", "path", "")
- if err == nil {
- t.Errorf("action with bad url must fail")
- }
- actions = actionsCopy
- }
- func TestRemoveNonexistentTransfer(t *testing.T) {
- transfer := Transfer{}
- err := removeTransfer(&transfer)
- if err == nil {
- t.Errorf("remove nonexistent transfer must fail")
- }
- }
- func TestRemoveNonexistentQuotaScan(t *testing.T) {
- err := RemoveQuotaScan("username")
- if err == nil {
- t.Errorf("remove nonexistent transfer must fail")
- }
- }
- func TestGetOSOpenFlags(t *testing.T) {
- var flags sftp.FileOpenFlags
- flags.Write = true
- flags.Append = true
- flags.Excl = true
- osFlags := getOSOpenFlags(flags)
- if osFlags&os.O_WRONLY == 0 || osFlags&os.O_APPEND == 0 || osFlags&os.O_EXCL == 0 {
- t.Errorf("error getting os flags from sftp file open flags")
- }
- }
- func TestUploadResume(t *testing.T) {
- c := Connection{}
- var flags sftp.FileOpenFlags
- _, err := c.handleSFTPUploadToExistingFile(flags, "", "", 0)
- if err != sftp.ErrSshFxOpUnsupported {
- t.Errorf("file resume is not supported")
- }
- }
- func TestUploadFiles(t *testing.T) {
- oldUploadMode := uploadMode
- uploadMode = uploadModeAtomic
- c := Connection{}
- var flags sftp.FileOpenFlags
- flags.Write = true
- flags.Trunc = true
- _, err := c.handleSFTPUploadToExistingFile(flags, "missing_path", "other_missing_path", 0)
- if err == nil {
- t.Errorf("upload to existing file must fail if one or both paths are invalid")
- }
- uploadMode = uploadModeStandard
- _, err = c.handleSFTPUploadToExistingFile(flags, "missing_path", "other_missing_path", 0)
- if err == nil {
- t.Errorf("upload to existing file must fail if one or both paths are invalid")
- }
- missingFile := "missing/relative/file.txt"
- if runtime.GOOS == "windows" {
- missingFile = "missing\\relative\\file.txt"
- }
- _, err = c.handleSFTPUploadToNewFile(".", missingFile)
- if err == nil {
- t.Errorf("upload new file in missing path must fail")
- }
- uploadMode = oldUploadMode
- }
- func TestWithInvalidHome(t *testing.T) {
- u := dataprovider.User{}
- u.HomeDir = "home_rel_path"
- _, err := loginUser(u)
- if err == nil {
- t.Errorf("login a user with an invalid home_dir must fail")
- }
- c := Connection{
- User: u,
- }
- err = c.isSubDir("dir_rel_path")
- if err == nil {
- t.Errorf("tested path is not a home subdir")
- }
- }
- func TestSFTPCmdTargetPath(t *testing.T) {
- u := dataprovider.User{}
- u.HomeDir = "home_rel_path"
- u.Username = "test"
- u.Permissions = []string{"*"}
- connection := Connection{
- User: u,
- }
- _, err := connection.getSFTPCmdTargetPath("invalid_path")
- if err != sftp.ErrSshFxOpUnsupported {
- t.Errorf("getSFTPCmdTargetPath must fal with the expected error: %v", err)
- }
- }
- func TestSFTPGetUsedQuota(t *testing.T) {
- u := dataprovider.User{}
- u.HomeDir = "home_rel_path"
- u.Username = "test_invalid_user"
- u.QuotaSize = 4096
- u.QuotaFiles = 1
- u.Permissions = []string{"*"}
- connection := Connection{
- User: u,
- }
- res := connection.hasSpace(false)
- if res != false {
- t.Errorf("has space must return false if the user is invalid")
- }
- }
- func TestSCPFileMode(t *testing.T) {
- mode := getFileModeAsString(0, true)
- if mode != "0755" {
- t.Errorf("invalid file mode: %v expected: 0755", mode)
- }
- mode = getFileModeAsString(0700, true)
- if mode != "0700" {
- t.Errorf("invalid file mode: %v expected: 0700", mode)
- }
- mode = getFileModeAsString(0750, true)
- if mode != "0750" {
- t.Errorf("invalid file mode: %v expected: 0750", mode)
- }
- mode = getFileModeAsString(0777, true)
- if mode != "0777" {
- t.Errorf("invalid file mode: %v expected: 0777", mode)
- }
- mode = getFileModeAsString(0640, false)
- if mode != "0640" {
- t.Errorf("invalid file mode: %v expected: 0640", mode)
- }
- mode = getFileModeAsString(0600, false)
- if mode != "0600" {
- t.Errorf("invalid file mode: %v expected: 0600", mode)
- }
- mode = getFileModeAsString(0, false)
- if mode != "0644" {
- t.Errorf("invalid file mode: %v expected: 0644", mode)
- }
- fileMode := uint32(0777)
- fileMode = fileMode | uint32(os.ModeSetgid)
- fileMode = fileMode | uint32(os.ModeSetuid)
- fileMode = fileMode | uint32(os.ModeSticky)
- mode = getFileModeAsString(os.FileMode(fileMode), false)
- if mode != "7777" {
- t.Errorf("invalid file mode: %v expected: 7777", mode)
- }
- fileMode = uint32(0644)
- fileMode = fileMode | uint32(os.ModeSetgid)
- mode = getFileModeAsString(os.FileMode(fileMode), false)
- if mode != "4644" {
- t.Errorf("invalid file mode: %v expected: 4644", mode)
- }
- fileMode = uint32(0600)
- fileMode = fileMode | uint32(os.ModeSetuid)
- mode = getFileModeAsString(os.FileMode(fileMode), false)
- if mode != "2600" {
- t.Errorf("invalid file mode: %v expected: 2600", mode)
- }
- fileMode = uint32(0044)
- fileMode = fileMode | uint32(os.ModeSticky)
- mode = getFileModeAsString(os.FileMode(fileMode), false)
- if mode != "1044" {
- t.Errorf("invalid file mode: %v expected: 1044", mode)
- }
- }
- func TestSCPGetNonExistingDirContent(t *testing.T) {
- _, err := getDirContents("non_existing")
- if err == nil {
- t.Errorf("get non existing dir contents must fail")
- }
- }
- func TestSCPParseUploadMessage(t *testing.T) {
- connection := Connection{}
- buf := make([]byte, 65535)
- stdErrBuf := make([]byte, 65535)
- mockSSHChannel := MockChannel{
- Buffer: bytes.NewBuffer(buf),
- StdErrBuffer: bytes.NewBuffer(stdErrBuf),
- ReadError: nil,
- }
- scpCommand := scpCommand{
- connection: connection,
- args: []string{"-t", "/tmp"},
- channel: &mockSSHChannel,
- }
- _, _, err := scpCommand.parseUploadMessage("invalid")
- if err == nil {
- t.Errorf("parsing invalid upload message must fail")
- }
- _, _, err = scpCommand.parseUploadMessage("D0755 0")
- if err == nil {
- t.Errorf("parsing incomplete upload message must fail")
- }
- _, _, err = scpCommand.parseUploadMessage("D0755 invalidsize testdir")
- if err == nil {
- t.Errorf("parsing upload message with invalid size must fail")
- }
- _, _, err = scpCommand.parseUploadMessage("D0755 0 ")
- if err == nil {
- t.Errorf("parsing upload message with invalid name must fail")
- }
- }
- func TestSCPProtocolMessages(t *testing.T) {
- connection := Connection{}
- buf := make([]byte, 65535)
- stdErrBuf := make([]byte, 65535)
- readErr := fmt.Errorf("test read error")
- writeErr := fmt.Errorf("test write error")
- mockSSHChannel := MockChannel{
- Buffer: bytes.NewBuffer(buf),
- StdErrBuffer: bytes.NewBuffer(stdErrBuf),
- ReadError: readErr,
- WriteError: writeErr,
- }
- scpCommand := scpCommand{
- connection: connection,
- args: []string{"-t", "/tmp"},
- channel: &mockSSHChannel,
- }
- _, err := scpCommand.readProtocolMessage()
- if err == nil || err != readErr {
- t.Errorf("read protocol message must fail, we are sending a fake error")
- }
- err = scpCommand.sendConfirmationMessage()
- if err != writeErr {
- t.Errorf("write confirmation message must fail, we are sending a fake error")
- }
- err = scpCommand.sendProtocolMessage("E\n")
- if err != writeErr {
- t.Errorf("write confirmation message must fail, we are sending a fake error")
- }
- _, err = scpCommand.getNextUploadProtocolMessage()
- if err == nil || err != readErr {
- t.Errorf("read next upload protocol message must fail, we are sending a fake read error")
- }
- mockSSHChannel = MockChannel{
- Buffer: bytes.NewBuffer([]byte("T1183832947 0 1183833773 0\n")),
- StdErrBuffer: bytes.NewBuffer(stdErrBuf),
- ReadError: nil,
- WriteError: writeErr,
- }
- scpCommand.channel = &mockSSHChannel
- _, err = scpCommand.getNextUploadProtocolMessage()
- if err == nil || err != writeErr {
- t.Errorf("read next upload protocol message must fail, we are sending a fake write error")
- }
- respBuffer := []byte{0x02}
- protocolErrorMsg := "protocol error msg"
- respBuffer = append(respBuffer, protocolErrorMsg...)
- respBuffer = append(respBuffer, 0x0A)
- mockSSHChannel = MockChannel{
- Buffer: bytes.NewBuffer(respBuffer),
- StdErrBuffer: bytes.NewBuffer(stdErrBuf),
- ReadError: nil,
- WriteError: nil,
- }
- scpCommand.channel = &mockSSHChannel
- err = scpCommand.readConfirmationMessage()
- if err == nil || err.Error() != protocolErrorMsg {
- t.Errorf("read confirmation message must return the expected protocol error, actual err: %v", err)
- }
- }
- func TestSCPTestDownloadProtocolMessages(t *testing.T) {
- connection := Connection{}
- buf := make([]byte, 65535)
- stdErrBuf := make([]byte, 65535)
- readErr := fmt.Errorf("test read error")
- writeErr := fmt.Errorf("test write error")
- mockSSHChannel := MockChannel{
- Buffer: bytes.NewBuffer(buf),
- StdErrBuffer: bytes.NewBuffer(stdErrBuf),
- ReadError: readErr,
- WriteError: writeErr,
- }
- scpCommand := scpCommand{
- connection: connection,
- args: []string{"-f", "-p", "/tmp"},
- channel: &mockSSHChannel,
- }
- path := "testDir"
- os.Mkdir(path, 0777)
- stat, _ := os.Stat(path)
- err := scpCommand.sendDownloadProtocolMessages(path, stat)
- if err != writeErr {
- t.Errorf("sendDownloadProtocolMessages must return the expected error: %v", err)
- }
- mockSSHChannel = MockChannel{
- Buffer: bytes.NewBuffer(buf),
- StdErrBuffer: bytes.NewBuffer(stdErrBuf),
- ReadError: readErr,
- WriteError: nil,
- }
- err = scpCommand.sendDownloadProtocolMessages(path, stat)
- if err != readErr {
- t.Errorf("sendDownloadProtocolMessages must return the expected error: %v", err)
- }
- mockSSHChannel = MockChannel{
- Buffer: bytes.NewBuffer(buf),
- StdErrBuffer: bytes.NewBuffer(stdErrBuf),
- ReadError: readErr,
- WriteError: writeErr,
- }
- scpCommand.args = []string{"-f", "/tmp"}
- scpCommand.channel = &mockSSHChannel
- err = scpCommand.sendDownloadProtocolMessages(path, stat)
- if err != writeErr {
- t.Errorf("sendDownloadProtocolMessages must return the expected error: %v", err)
- }
- mockSSHChannel = MockChannel{
- Buffer: bytes.NewBuffer(buf),
- StdErrBuffer: bytes.NewBuffer(stdErrBuf),
- ReadError: readErr,
- WriteError: nil,
- }
- scpCommand.channel = &mockSSHChannel
- err = scpCommand.sendDownloadProtocolMessages(path, stat)
- if err != readErr {
- t.Errorf("sendDownloadProtocolMessages must return the expected error: %v", err)
- }
- os.Remove(path)
- }
- func TestSCPCommandHandleErrors(t *testing.T) {
- connection := Connection{}
- buf := make([]byte, 65535)
- stdErrBuf := make([]byte, 65535)
- readErr := fmt.Errorf("test read error")
- writeErr := fmt.Errorf("test write error")
- mockSSHChannel := MockChannel{
- Buffer: bytes.NewBuffer(buf),
- StdErrBuffer: bytes.NewBuffer(stdErrBuf),
- ReadError: readErr,
- WriteError: writeErr,
- }
- scpCommand := scpCommand{
- connection: connection,
- args: []string{"-f", "/tmp"},
- channel: &mockSSHChannel,
- }
- err := scpCommand.handle()
- if err == nil || err != readErr {
- t.Errorf("scp download must fail, we are sending a fake error")
- }
- scpCommand.args = []string{"-i", "/tmp"}
- err = scpCommand.handle()
- if err == nil {
- t.Errorf("invalid scp command must fail")
- }
- }
- func TestRecursiveDownloadErrors(t *testing.T) {
- connection := Connection{}
- buf := make([]byte, 65535)
- stdErrBuf := make([]byte, 65535)
- readErr := fmt.Errorf("test read error")
- writeErr := fmt.Errorf("test write error")
- mockSSHChannel := MockChannel{
- Buffer: bytes.NewBuffer(buf),
- StdErrBuffer: bytes.NewBuffer(stdErrBuf),
- ReadError: readErr,
- WriteError: writeErr,
- }
- scpCommand := scpCommand{
- connection: connection,
- args: []string{"-r", "-f", "/tmp"},
- channel: &mockSSHChannel,
- }
- path := "testDir"
- os.Mkdir(path, 0777)
- stat, _ := os.Stat(path)
- err := scpCommand.handleRecursiveDownload("invalid_dir", stat)
- if err != writeErr {
- t.Errorf("recursive upload download must fail with the expected error: %v", err)
- }
- mockSSHChannel = MockChannel{
- Buffer: bytes.NewBuffer(buf),
- StdErrBuffer: bytes.NewBuffer(stdErrBuf),
- ReadError: nil,
- WriteError: nil,
- }
- scpCommand.channel = &mockSSHChannel
- err = scpCommand.handleRecursiveDownload("invalid_dir", stat)
- if err == nil {
- t.Errorf("recursive upload download must fail for a non existing dir")
- }
- os.Remove(path)
- }
- func TestRecursiveUploadErrors(t *testing.T) {
- connection := Connection{}
- buf := make([]byte, 65535)
- stdErrBuf := make([]byte, 65535)
- readErr := fmt.Errorf("test read error")
- writeErr := fmt.Errorf("test write error")
- mockSSHChannel := MockChannel{
- Buffer: bytes.NewBuffer(buf),
- StdErrBuffer: bytes.NewBuffer(stdErrBuf),
- ReadError: readErr,
- WriteError: writeErr,
- }
- scpCommand := scpCommand{
- connection: connection,
- args: []string{"-r", "-t", "/tmp"},
- channel: &mockSSHChannel,
- }
- err := scpCommand.handleRecursiveUpload()
- if err == nil {
- t.Errorf("recursive upload must fail, we send a fake error message")
- }
- mockSSHChannel = MockChannel{
- Buffer: bytes.NewBuffer(buf),
- StdErrBuffer: bytes.NewBuffer(stdErrBuf),
- ReadError: readErr,
- WriteError: nil,
- }
- scpCommand.channel = &mockSSHChannel
- err = scpCommand.handleRecursiveUpload()
- if err == nil {
- t.Errorf("recursive upload must fail, we send a fake error message")
- }
- }
- func TestSCPCreateDirs(t *testing.T) {
- buf := make([]byte, 65535)
- stdErrBuf := make([]byte, 65535)
- u := dataprovider.User{}
- u.HomeDir = "home_rel_path"
- u.Username = "test"
- u.Permissions = []string{"*"}
- connection := Connection{
- User: u,
- }
- mockSSHChannel := MockChannel{
- Buffer: bytes.NewBuffer(buf),
- StdErrBuffer: bytes.NewBuffer(stdErrBuf),
- ReadError: nil,
- WriteError: nil,
- }
- scpCommand := scpCommand{
- connection: connection,
- args: []string{"-r", "-t", "/tmp"},
- channel: &mockSSHChannel,
- }
- err := scpCommand.handleCreateDir("invalid_dir")
- if err == nil {
- t.Errorf("create invalid dir must fail")
- }
- }
- func TestSCPDownloadFileData(t *testing.T) {
- testfile := "testfile"
- buf := make([]byte, 65535)
- readErr := fmt.Errorf("test read error")
- writeErr := fmt.Errorf("test write error")
- stdErrBuf := make([]byte, 65535)
- connection := Connection{}
- mockSSHChannelReadErr := MockChannel{
- Buffer: bytes.NewBuffer(buf),
- StdErrBuffer: bytes.NewBuffer(stdErrBuf),
- ReadError: readErr,
- WriteError: nil,
- }
- mockSSHChannelWriteErr := MockChannel{
- Buffer: bytes.NewBuffer(buf),
- StdErrBuffer: bytes.NewBuffer(stdErrBuf),
- ReadError: nil,
- WriteError: writeErr,
- }
- scpCommand := scpCommand{
- connection: connection,
- args: []string{"-r", "-f", "/tmp"},
- channel: &mockSSHChannelReadErr,
- }
- ioutil.WriteFile(testfile, []byte("test"), 0666)
- stat, _ := os.Stat(testfile)
- err := scpCommand.sendDownloadFileData(testfile, stat, nil)
- if err != readErr {
- t.Errorf("send download file data must fail with the expected error: %v", err)
- }
- scpCommand.channel = &mockSSHChannelWriteErr
- err = scpCommand.sendDownloadFileData(testfile, stat, nil)
- if err != writeErr {
- t.Errorf("send download file data must fail with the expected error: %v", err)
- }
- scpCommand.args = []string{"-r", "-p", "-f", "/tmp"}
- err = scpCommand.sendDownloadFileData(testfile, stat, nil)
- if err != writeErr {
- t.Errorf("send download file data must fail with the expected error: %v", err)
- }
- scpCommand.channel = &mockSSHChannelReadErr
- err = scpCommand.sendDownloadFileData(testfile, stat, nil)
- if err != readErr {
- t.Errorf("send download file data must fail with the expected error: %v", err)
- }
- os.Remove(testfile)
- }
- func TestSCPUploadFiledata(t *testing.T) {
- testfile := "testfile"
- connection := Connection{
- User: dataprovider.User{
- Username: "testuser",
- },
- protocol: protocolSCP,
- }
- buf := make([]byte, 65535)
- stdErrBuf := make([]byte, 65535)
- readErr := fmt.Errorf("test read error")
- writeErr := fmt.Errorf("test write error")
- mockSSHChannel := MockChannel{
- Buffer: bytes.NewBuffer(buf),
- StdErrBuffer: bytes.NewBuffer(stdErrBuf),
- ReadError: readErr,
- WriteError: writeErr,
- }
- scpCommand := scpCommand{
- connection: connection,
- args: []string{"-r", "-t", "/tmp"},
- channel: &mockSSHChannel,
- }
- file, _ := os.Create(testfile)
- transfer := Transfer{
- file: file,
- path: file.Name(),
- start: time.Now(),
- bytesSent: 0,
- bytesReceived: 0,
- user: scpCommand.connection.User,
- connectionID: "",
- transferType: transferDownload,
- lastActivity: time.Now(),
- isNewFile: true,
- protocol: connection.protocol,
- }
- addTransfer(&transfer)
- err := scpCommand.getUploadFileData(2, &transfer)
- if err == nil {
- t.Errorf("upload must fail, we send a fake write error message")
- }
- mockSSHChannel = MockChannel{
- Buffer: bytes.NewBuffer(buf),
- StdErrBuffer: bytes.NewBuffer(stdErrBuf),
- ReadError: readErr,
- WriteError: nil,
- }
- scpCommand.channel = &mockSSHChannel
- file, _ = os.Create(testfile)
- transfer.file = file
- addTransfer(&transfer)
- err = scpCommand.getUploadFileData(2, &transfer)
- if err == nil {
- t.Errorf("upload must fail, we send a fake read error message")
- }
- respBuffer := []byte("12")
- respBuffer = append(respBuffer, 0x02)
- mockSSHChannel = MockChannel{
- Buffer: bytes.NewBuffer(respBuffer),
- StdErrBuffer: bytes.NewBuffer(stdErrBuf),
- ReadError: nil,
- WriteError: nil,
- }
- scpCommand.channel = &mockSSHChannel
- file, _ = os.Create(testfile)
- transfer.file = file
- addTransfer(&transfer)
- err = scpCommand.getUploadFileData(2, &transfer)
- if err == nil {
- t.Errorf("upload must fail, we have not enough data to read")
- }
- // the file is already closed so we have an error on trasfer closing
- mockSSHChannel = MockChannel{
- Buffer: bytes.NewBuffer(buf),
- StdErrBuffer: bytes.NewBuffer(stdErrBuf),
- ReadError: nil,
- WriteError: nil,
- }
- addTransfer(&transfer)
- err = scpCommand.getUploadFileData(0, &transfer)
- if err == nil {
- t.Errorf("upload must fail, the file is closed")
- }
- os.Remove(testfile)
- }
|