internal_test.go 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674
  1. package sftpd
  2. import (
  3. "bytes"
  4. "fmt"
  5. "io"
  6. "io/ioutil"
  7. "os"
  8. "runtime"
  9. "testing"
  10. "time"
  11. "github.com/drakkan/sftpgo/dataprovider"
  12. "github.com/pkg/sftp"
  13. )
  // MockChannel is an in-memory stand-in for an SSH channel used by the tests
  // in this file. Read and Write operate on Buffer unless the matching error
  // field is set, in which case that error is returned instead.
  type MockChannel struct {
  	// Buffer backs Read and Write; StdErrBuffer is what Stderr hands out.
  	Buffer, StdErrBuffer *bytes.Buffer
  	// ReadError and WriteError, when non-nil, are returned by Read and
  	// Write respectively.
  	ReadError, WriteError error
  }
  20. func (c *MockChannel) Read(data []byte) (int, error) {
  21. if c.ReadError != nil {
  22. return 0, c.ReadError
  23. }
  24. return c.Buffer.Read(data)
  25. }
  26. func (c *MockChannel) Write(data []byte) (int, error) {
  27. if c.WriteError != nil {
  28. return 0, c.WriteError
  29. }
  30. return c.Buffer.Write(data)
  31. }
  32. func (c *MockChannel) Close() error {
  33. return nil
  34. }
  35. func (c *MockChannel) CloseWrite() error {
  36. return nil
  37. }
  38. func (c *MockChannel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
  39. return true, nil
  40. }
  41. func (c *MockChannel) Stderr() io.ReadWriter {
  42. return c.StdErrBuffer
  43. }
  44. func TestWrongActions(t *testing.T) {
  45. actionsCopy := actions
  46. badCommand := "/bad/command"
  47. if runtime.GOOS == "windows" {
  48. badCommand = "C:\\bad\\command"
  49. }
  50. actions = Actions{
  51. ExecuteOn: []string{operationDownload},
  52. Command: badCommand,
  53. HTTPNotificationURL: "",
  54. }
  55. err := executeAction(operationDownload, "username", "path", "")
  56. if err == nil {
  57. t.Errorf("action with bad command must fail")
  58. }
  59. err = executeAction(operationDelete, "username", "path", "")
  60. if err != nil {
  61. t.Errorf("action not configured must silently fail")
  62. }
  63. actions.Command = ""
  64. actions.HTTPNotificationURL = "http://foo\x7f.com/"
  65. err = executeAction(operationDownload, "username", "path", "")
  66. if err == nil {
  67. t.Errorf("action with bad url must fail")
  68. }
  69. actions = actionsCopy
  70. }
  71. func TestRemoveNonexistentTransfer(t *testing.T) {
  72. transfer := Transfer{}
  73. err := removeTransfer(&transfer)
  74. if err == nil {
  75. t.Errorf("remove nonexistent transfer must fail")
  76. }
  77. }
  78. func TestRemoveNonexistentQuotaScan(t *testing.T) {
  79. err := RemoveQuotaScan("username")
  80. if err == nil {
  81. t.Errorf("remove nonexistent transfer must fail")
  82. }
  83. }
  84. func TestGetOSOpenFlags(t *testing.T) {
  85. var flags sftp.FileOpenFlags
  86. flags.Write = true
  87. flags.Append = true
  88. flags.Excl = true
  89. osFlags := getOSOpenFlags(flags)
  90. if osFlags&os.O_WRONLY == 0 || osFlags&os.O_APPEND == 0 || osFlags&os.O_EXCL == 0 {
  91. t.Errorf("error getting os flags from sftp file open flags")
  92. }
  93. }
  94. func TestUploadResume(t *testing.T) {
  95. c := Connection{}
  96. var flags sftp.FileOpenFlags
  97. _, err := c.handleSFTPUploadToExistingFile(flags, "", "", 0)
  98. if err != sftp.ErrSshFxOpUnsupported {
  99. t.Errorf("file resume is not supported")
  100. }
  101. }
  102. func TestUploadFiles(t *testing.T) {
  103. oldUploadMode := uploadMode
  104. uploadMode = uploadModeAtomic
  105. c := Connection{}
  106. var flags sftp.FileOpenFlags
  107. flags.Write = true
  108. flags.Trunc = true
  109. _, err := c.handleSFTPUploadToExistingFile(flags, "missing_path", "other_missing_path", 0)
  110. if err == nil {
  111. t.Errorf("upload to existing file must fail if one or both paths are invalid")
  112. }
  113. uploadMode = uploadModeStandard
  114. _, err = c.handleSFTPUploadToExistingFile(flags, "missing_path", "other_missing_path", 0)
  115. if err == nil {
  116. t.Errorf("upload to existing file must fail if one or both paths are invalid")
  117. }
  118. missingFile := "missing/relative/file.txt"
  119. if runtime.GOOS == "windows" {
  120. missingFile = "missing\\relative\\file.txt"
  121. }
  122. _, err = c.handleSFTPUploadToNewFile(".", missingFile)
  123. if err == nil {
  124. t.Errorf("upload new file in missing path must fail")
  125. }
  126. uploadMode = oldUploadMode
  127. }
  128. func TestWithInvalidHome(t *testing.T) {
  129. u := dataprovider.User{}
  130. u.HomeDir = "home_rel_path"
  131. _, err := loginUser(u)
  132. if err == nil {
  133. t.Errorf("login a user with an invalid home_dir must fail")
  134. }
  135. c := Connection{
  136. User: u,
  137. }
  138. err = c.isSubDir("dir_rel_path")
  139. if err == nil {
  140. t.Errorf("tested path is not a home subdir")
  141. }
  142. }
  143. func TestSFTPCmdTargetPath(t *testing.T) {
  144. u := dataprovider.User{}
  145. u.HomeDir = "home_rel_path"
  146. u.Username = "test"
  147. u.Permissions = []string{"*"}
  148. connection := Connection{
  149. User: u,
  150. }
  151. _, err := connection.getSFTPCmdTargetPath("invalid_path")
  152. if err != sftp.ErrSshFxOpUnsupported {
  153. t.Errorf("getSFTPCmdTargetPath must fal with the expected error: %v", err)
  154. }
  155. }
  156. func TestSFTPGetUsedQuota(t *testing.T) {
  157. u := dataprovider.User{}
  158. u.HomeDir = "home_rel_path"
  159. u.Username = "test_invalid_user"
  160. u.QuotaSize = 4096
  161. u.QuotaFiles = 1
  162. u.Permissions = []string{"*"}
  163. connection := Connection{
  164. User: u,
  165. }
  166. res := connection.hasSpace(false)
  167. if res != false {
  168. t.Errorf("has space must return false if the user is invalid")
  169. }
  170. }
  171. func TestSCPFileMode(t *testing.T) {
  172. mode := getFileModeAsString(0, true)
  173. if mode != "0755" {
  174. t.Errorf("invalid file mode: %v expected: 0755", mode)
  175. }
  176. mode = getFileModeAsString(0700, true)
  177. if mode != "0700" {
  178. t.Errorf("invalid file mode: %v expected: 0700", mode)
  179. }
  180. mode = getFileModeAsString(0750, true)
  181. if mode != "0750" {
  182. t.Errorf("invalid file mode: %v expected: 0750", mode)
  183. }
  184. mode = getFileModeAsString(0777, true)
  185. if mode != "0777" {
  186. t.Errorf("invalid file mode: %v expected: 0777", mode)
  187. }
  188. mode = getFileModeAsString(0640, false)
  189. if mode != "0640" {
  190. t.Errorf("invalid file mode: %v expected: 0640", mode)
  191. }
  192. mode = getFileModeAsString(0600, false)
  193. if mode != "0600" {
  194. t.Errorf("invalid file mode: %v expected: 0600", mode)
  195. }
  196. mode = getFileModeAsString(0, false)
  197. if mode != "0644" {
  198. t.Errorf("invalid file mode: %v expected: 0644", mode)
  199. }
  200. fileMode := uint32(0777)
  201. fileMode = fileMode | uint32(os.ModeSetgid)
  202. fileMode = fileMode | uint32(os.ModeSetuid)
  203. fileMode = fileMode | uint32(os.ModeSticky)
  204. mode = getFileModeAsString(os.FileMode(fileMode), false)
  205. if mode != "7777" {
  206. t.Errorf("invalid file mode: %v expected: 7777", mode)
  207. }
  208. fileMode = uint32(0644)
  209. fileMode = fileMode | uint32(os.ModeSetgid)
  210. mode = getFileModeAsString(os.FileMode(fileMode), false)
  211. if mode != "4644" {
  212. t.Errorf("invalid file mode: %v expected: 4644", mode)
  213. }
  214. fileMode = uint32(0600)
  215. fileMode = fileMode | uint32(os.ModeSetuid)
  216. mode = getFileModeAsString(os.FileMode(fileMode), false)
  217. if mode != "2600" {
  218. t.Errorf("invalid file mode: %v expected: 2600", mode)
  219. }
  220. fileMode = uint32(0044)
  221. fileMode = fileMode | uint32(os.ModeSticky)
  222. mode = getFileModeAsString(os.FileMode(fileMode), false)
  223. if mode != "1044" {
  224. t.Errorf("invalid file mode: %v expected: 1044", mode)
  225. }
  226. }
  227. func TestSCPGetNonExistingDirContent(t *testing.T) {
  228. _, err := getDirContents("non_existing")
  229. if err == nil {
  230. t.Errorf("get non existing dir contents must fail")
  231. }
  232. }
  233. func TestSCPParseUploadMessage(t *testing.T) {
  234. connection := Connection{}
  235. buf := make([]byte, 65535)
  236. stdErrBuf := make([]byte, 65535)
  237. mockSSHChannel := MockChannel{
  238. Buffer: bytes.NewBuffer(buf),
  239. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  240. ReadError: nil,
  241. }
  242. scpCommand := scpCommand{
  243. connection: connection,
  244. args: []string{"-t", "/tmp"},
  245. channel: &mockSSHChannel,
  246. }
  247. _, _, err := scpCommand.parseUploadMessage("invalid")
  248. if err == nil {
  249. t.Errorf("parsing invalid upload message must fail")
  250. }
  251. _, _, err = scpCommand.parseUploadMessage("D0755 0")
  252. if err == nil {
  253. t.Errorf("parsing incomplete upload message must fail")
  254. }
  255. _, _, err = scpCommand.parseUploadMessage("D0755 invalidsize testdir")
  256. if err == nil {
  257. t.Errorf("parsing upload message with invalid size must fail")
  258. }
  259. _, _, err = scpCommand.parseUploadMessage("D0755 0 ")
  260. if err == nil {
  261. t.Errorf("parsing upload message with invalid name must fail")
  262. }
  263. }
  264. func TestSCPProtocolMessages(t *testing.T) {
  265. connection := Connection{}
  266. buf := make([]byte, 65535)
  267. stdErrBuf := make([]byte, 65535)
  268. readErr := fmt.Errorf("test read error")
  269. writeErr := fmt.Errorf("test write error")
  270. mockSSHChannel := MockChannel{
  271. Buffer: bytes.NewBuffer(buf),
  272. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  273. ReadError: readErr,
  274. WriteError: writeErr,
  275. }
  276. scpCommand := scpCommand{
  277. connection: connection,
  278. args: []string{"-t", "/tmp"},
  279. channel: &mockSSHChannel,
  280. }
  281. _, err := scpCommand.readProtocolMessage()
  282. if err == nil || err != readErr {
  283. t.Errorf("read protocol message must fail, we are sending a fake error")
  284. }
  285. err = scpCommand.sendConfirmationMessage()
  286. if err != writeErr {
  287. t.Errorf("write confirmation message must fail, we are sending a fake error")
  288. }
  289. err = scpCommand.sendProtocolMessage("E\n")
  290. if err != writeErr {
  291. t.Errorf("write confirmation message must fail, we are sending a fake error")
  292. }
  293. _, err = scpCommand.getNextUploadProtocolMessage()
  294. if err == nil || err != readErr {
  295. t.Errorf("read next upload protocol message must fail, we are sending a fake read error")
  296. }
  297. mockSSHChannel = MockChannel{
  298. Buffer: bytes.NewBuffer([]byte("T1183832947 0 1183833773 0\n")),
  299. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  300. ReadError: nil,
  301. WriteError: writeErr,
  302. }
  303. scpCommand.channel = &mockSSHChannel
  304. _, err = scpCommand.getNextUploadProtocolMessage()
  305. if err == nil || err != writeErr {
  306. t.Errorf("read next upload protocol message must fail, we are sending a fake write error")
  307. }
  308. respBuffer := []byte{0x02}
  309. protocolErrorMsg := "protocol error msg"
  310. respBuffer = append(respBuffer, protocolErrorMsg...)
  311. respBuffer = append(respBuffer, 0x0A)
  312. mockSSHChannel = MockChannel{
  313. Buffer: bytes.NewBuffer(respBuffer),
  314. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  315. ReadError: nil,
  316. WriteError: nil,
  317. }
  318. scpCommand.channel = &mockSSHChannel
  319. err = scpCommand.readConfirmationMessage()
  320. if err == nil || err.Error() != protocolErrorMsg {
  321. t.Errorf("read confirmation message must return the expected protocol error, actual err: %v", err)
  322. }
  323. }
  324. func TestSCPTestDownloadProtocolMessages(t *testing.T) {
  325. connection := Connection{}
  326. buf := make([]byte, 65535)
  327. stdErrBuf := make([]byte, 65535)
  328. readErr := fmt.Errorf("test read error")
  329. writeErr := fmt.Errorf("test write error")
  330. mockSSHChannel := MockChannel{
  331. Buffer: bytes.NewBuffer(buf),
  332. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  333. ReadError: readErr,
  334. WriteError: writeErr,
  335. }
  336. scpCommand := scpCommand{
  337. connection: connection,
  338. args: []string{"-f", "-p", "/tmp"},
  339. channel: &mockSSHChannel,
  340. }
  341. path := "testDir"
  342. os.Mkdir(path, 0777)
  343. stat, _ := os.Stat(path)
  344. err := scpCommand.sendDownloadProtocolMessages(path, stat)
  345. if err != writeErr {
  346. t.Errorf("sendDownloadProtocolMessages must return the expected error: %v", err)
  347. }
  348. mockSSHChannel = MockChannel{
  349. Buffer: bytes.NewBuffer(buf),
  350. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  351. ReadError: readErr,
  352. WriteError: nil,
  353. }
  354. err = scpCommand.sendDownloadProtocolMessages(path, stat)
  355. if err != readErr {
  356. t.Errorf("sendDownloadProtocolMessages must return the expected error: %v", err)
  357. }
  358. mockSSHChannel = MockChannel{
  359. Buffer: bytes.NewBuffer(buf),
  360. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  361. ReadError: readErr,
  362. WriteError: writeErr,
  363. }
  364. scpCommand.args = []string{"-f", "/tmp"}
  365. scpCommand.channel = &mockSSHChannel
  366. err = scpCommand.sendDownloadProtocolMessages(path, stat)
  367. if err != writeErr {
  368. t.Errorf("sendDownloadProtocolMessages must return the expected error: %v", err)
  369. }
  370. mockSSHChannel = MockChannel{
  371. Buffer: bytes.NewBuffer(buf),
  372. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  373. ReadError: readErr,
  374. WriteError: nil,
  375. }
  376. scpCommand.channel = &mockSSHChannel
  377. err = scpCommand.sendDownloadProtocolMessages(path, stat)
  378. if err != readErr {
  379. t.Errorf("sendDownloadProtocolMessages must return the expected error: %v", err)
  380. }
  381. os.Remove(path)
  382. }
  383. func TestSCPCommandHandleErrors(t *testing.T) {
  384. connection := Connection{}
  385. buf := make([]byte, 65535)
  386. stdErrBuf := make([]byte, 65535)
  387. readErr := fmt.Errorf("test read error")
  388. writeErr := fmt.Errorf("test write error")
  389. mockSSHChannel := MockChannel{
  390. Buffer: bytes.NewBuffer(buf),
  391. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  392. ReadError: readErr,
  393. WriteError: writeErr,
  394. }
  395. scpCommand := scpCommand{
  396. connection: connection,
  397. args: []string{"-f", "/tmp"},
  398. channel: &mockSSHChannel,
  399. }
  400. err := scpCommand.handle()
  401. if err == nil || err != readErr {
  402. t.Errorf("scp download must fail, we are sending a fake error")
  403. }
  404. scpCommand.args = []string{"-i", "/tmp"}
  405. err = scpCommand.handle()
  406. if err == nil {
  407. t.Errorf("invalid scp command must fail")
  408. }
  409. }
  410. func TestRecursiveDownloadErrors(t *testing.T) {
  411. connection := Connection{}
  412. buf := make([]byte, 65535)
  413. stdErrBuf := make([]byte, 65535)
  414. readErr := fmt.Errorf("test read error")
  415. writeErr := fmt.Errorf("test write error")
  416. mockSSHChannel := MockChannel{
  417. Buffer: bytes.NewBuffer(buf),
  418. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  419. ReadError: readErr,
  420. WriteError: writeErr,
  421. }
  422. scpCommand := scpCommand{
  423. connection: connection,
  424. args: []string{"-r", "-f", "/tmp"},
  425. channel: &mockSSHChannel,
  426. }
  427. path := "testDir"
  428. os.Mkdir(path, 0777)
  429. stat, _ := os.Stat(path)
  430. err := scpCommand.handleRecursiveDownload("invalid_dir", stat)
  431. if err != writeErr {
  432. t.Errorf("recursive upload download must fail with the expected error: %v", err)
  433. }
  434. mockSSHChannel = MockChannel{
  435. Buffer: bytes.NewBuffer(buf),
  436. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  437. ReadError: nil,
  438. WriteError: nil,
  439. }
  440. scpCommand.channel = &mockSSHChannel
  441. err = scpCommand.handleRecursiveDownload("invalid_dir", stat)
  442. if err == nil {
  443. t.Errorf("recursive upload download must fail for a non existing dir")
  444. }
  445. os.Remove(path)
  446. }
  447. func TestRecursiveUploadErrors(t *testing.T) {
  448. connection := Connection{}
  449. buf := make([]byte, 65535)
  450. stdErrBuf := make([]byte, 65535)
  451. readErr := fmt.Errorf("test read error")
  452. writeErr := fmt.Errorf("test write error")
  453. mockSSHChannel := MockChannel{
  454. Buffer: bytes.NewBuffer(buf),
  455. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  456. ReadError: readErr,
  457. WriteError: writeErr,
  458. }
  459. scpCommand := scpCommand{
  460. connection: connection,
  461. args: []string{"-r", "-t", "/tmp"},
  462. channel: &mockSSHChannel,
  463. }
  464. err := scpCommand.handleRecursiveUpload()
  465. if err == nil {
  466. t.Errorf("recursive upload must fail, we send a fake error message")
  467. }
  468. mockSSHChannel = MockChannel{
  469. Buffer: bytes.NewBuffer(buf),
  470. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  471. ReadError: readErr,
  472. WriteError: nil,
  473. }
  474. scpCommand.channel = &mockSSHChannel
  475. err = scpCommand.handleRecursiveUpload()
  476. if err == nil {
  477. t.Errorf("recursive upload must fail, we send a fake error message")
  478. }
  479. }
  480. func TestSCPCreateDirs(t *testing.T) {
  481. buf := make([]byte, 65535)
  482. stdErrBuf := make([]byte, 65535)
  483. u := dataprovider.User{}
  484. u.HomeDir = "home_rel_path"
  485. u.Username = "test"
  486. u.Permissions = []string{"*"}
  487. connection := Connection{
  488. User: u,
  489. }
  490. mockSSHChannel := MockChannel{
  491. Buffer: bytes.NewBuffer(buf),
  492. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  493. ReadError: nil,
  494. WriteError: nil,
  495. }
  496. scpCommand := scpCommand{
  497. connection: connection,
  498. args: []string{"-r", "-t", "/tmp"},
  499. channel: &mockSSHChannel,
  500. }
  501. err := scpCommand.handleCreateDir("invalid_dir")
  502. if err == nil {
  503. t.Errorf("create invalid dir must fail")
  504. }
  505. }
  506. func TestSCPDownloadFileData(t *testing.T) {
  507. testfile := "testfile"
  508. buf := make([]byte, 65535)
  509. readErr := fmt.Errorf("test read error")
  510. writeErr := fmt.Errorf("test write error")
  511. stdErrBuf := make([]byte, 65535)
  512. connection := Connection{}
  513. mockSSHChannelReadErr := MockChannel{
  514. Buffer: bytes.NewBuffer(buf),
  515. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  516. ReadError: readErr,
  517. WriteError: nil,
  518. }
  519. mockSSHChannelWriteErr := MockChannel{
  520. Buffer: bytes.NewBuffer(buf),
  521. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  522. ReadError: nil,
  523. WriteError: writeErr,
  524. }
  525. scpCommand := scpCommand{
  526. connection: connection,
  527. args: []string{"-r", "-f", "/tmp"},
  528. channel: &mockSSHChannelReadErr,
  529. }
  530. ioutil.WriteFile(testfile, []byte("test"), 0666)
  531. stat, _ := os.Stat(testfile)
  532. err := scpCommand.sendDownloadFileData(testfile, stat, nil)
  533. if err != readErr {
  534. t.Errorf("send download file data must fail with the expected error: %v", err)
  535. }
  536. scpCommand.channel = &mockSSHChannelWriteErr
  537. err = scpCommand.sendDownloadFileData(testfile, stat, nil)
  538. if err != writeErr {
  539. t.Errorf("send download file data must fail with the expected error: %v", err)
  540. }
  541. scpCommand.args = []string{"-r", "-p", "-f", "/tmp"}
  542. err = scpCommand.sendDownloadFileData(testfile, stat, nil)
  543. if err != writeErr {
  544. t.Errorf("send download file data must fail with the expected error: %v", err)
  545. }
  546. scpCommand.channel = &mockSSHChannelReadErr
  547. err = scpCommand.sendDownloadFileData(testfile, stat, nil)
  548. if err != readErr {
  549. t.Errorf("send download file data must fail with the expected error: %v", err)
  550. }
  551. os.Remove(testfile)
  552. }
  553. func TestSCPUploadFiledata(t *testing.T) {
  554. testfile := "testfile"
  555. connection := Connection{
  556. User: dataprovider.User{
  557. Username: "testuser",
  558. },
  559. protocol: protocolSCP,
  560. }
  561. buf := make([]byte, 65535)
  562. stdErrBuf := make([]byte, 65535)
  563. readErr := fmt.Errorf("test read error")
  564. writeErr := fmt.Errorf("test write error")
  565. mockSSHChannel := MockChannel{
  566. Buffer: bytes.NewBuffer(buf),
  567. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  568. ReadError: readErr,
  569. WriteError: writeErr,
  570. }
  571. scpCommand := scpCommand{
  572. connection: connection,
  573. args: []string{"-r", "-t", "/tmp"},
  574. channel: &mockSSHChannel,
  575. }
  576. file, _ := os.Create(testfile)
  577. transfer := Transfer{
  578. file: file,
  579. path: file.Name(),
  580. start: time.Now(),
  581. bytesSent: 0,
  582. bytesReceived: 0,
  583. user: scpCommand.connection.User,
  584. connectionID: "",
  585. transferType: transferDownload,
  586. lastActivity: time.Now(),
  587. isNewFile: true,
  588. protocol: connection.protocol,
  589. }
  590. addTransfer(&transfer)
  591. err := scpCommand.getUploadFileData(2, &transfer)
  592. if err == nil {
  593. t.Errorf("upload must fail, we send a fake write error message")
  594. }
  595. mockSSHChannel = MockChannel{
  596. Buffer: bytes.NewBuffer(buf),
  597. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  598. ReadError: readErr,
  599. WriteError: nil,
  600. }
  601. scpCommand.channel = &mockSSHChannel
  602. file, _ = os.Create(testfile)
  603. transfer.file = file
  604. addTransfer(&transfer)
  605. err = scpCommand.getUploadFileData(2, &transfer)
  606. if err == nil {
  607. t.Errorf("upload must fail, we send a fake read error message")
  608. }
  609. respBuffer := []byte("12")
  610. respBuffer = append(respBuffer, 0x02)
  611. mockSSHChannel = MockChannel{
  612. Buffer: bytes.NewBuffer(respBuffer),
  613. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  614. ReadError: nil,
  615. WriteError: nil,
  616. }
  617. scpCommand.channel = &mockSSHChannel
  618. file, _ = os.Create(testfile)
  619. transfer.file = file
  620. addTransfer(&transfer)
  621. err = scpCommand.getUploadFileData(2, &transfer)
  622. if err == nil {
  623. t.Errorf("upload must fail, we have not enough data to read")
  624. }
  625. // the file is already closed so we have an error on trasfer closing
  626. mockSSHChannel = MockChannel{
  627. Buffer: bytes.NewBuffer(buf),
  628. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  629. ReadError: nil,
  630. WriteError: nil,
  631. }
  632. addTransfer(&transfer)
  633. err = scpCommand.getUploadFileData(0, &transfer)
  634. if err == nil {
  635. t.Errorf("upload must fail, the file is closed")
  636. }
  637. os.Remove(testfile)
  638. }