transfer_test.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455
  1. package common
  2. import (
  3. "errors"
  4. "os"
  5. "path/filepath"
  6. "testing"
  7. "time"
  8. "github.com/sftpgo/sdk"
  9. "github.com/stretchr/testify/assert"
  10. "github.com/stretchr/testify/require"
  11. "github.com/drakkan/sftpgo/v2/dataprovider"
  12. "github.com/drakkan/sftpgo/v2/kms"
  13. "github.com/drakkan/sftpgo/v2/vfs"
  14. )
  15. func TestTransferUpdateQuota(t *testing.T) {
  16. conn := NewBaseConnection("", ProtocolSFTP, "", "", dataprovider.User{})
  17. transfer := BaseTransfer{
  18. Connection: conn,
  19. transferType: TransferUpload,
  20. BytesReceived: 123,
  21. Fs: vfs.NewOsFs("", os.TempDir(), ""),
  22. }
  23. errFake := errors.New("fake error")
  24. transfer.TransferError(errFake)
  25. assert.False(t, transfer.updateQuota(1, 0))
  26. err := transfer.Close()
  27. if assert.Error(t, err) {
  28. assert.EqualError(t, err, errFake.Error())
  29. }
  30. mappedPath := filepath.Join(os.TempDir(), "vdir")
  31. vdirPath := "/vdir"
  32. conn.User.VirtualFolders = append(conn.User.VirtualFolders, vfs.VirtualFolder{
  33. BaseVirtualFolder: vfs.BaseVirtualFolder{
  34. MappedPath: mappedPath,
  35. },
  36. VirtualPath: vdirPath,
  37. QuotaFiles: -1,
  38. QuotaSize: -1,
  39. })
  40. transfer.ErrTransfer = nil
  41. transfer.BytesReceived = 1
  42. transfer.requestPath = "/vdir/file"
  43. assert.True(t, transfer.updateQuota(1, 0))
  44. err = transfer.Close()
  45. assert.NoError(t, err)
  46. }
  47. func TestTransferThrottling(t *testing.T) {
  48. u := dataprovider.User{
  49. BaseUser: sdk.BaseUser{
  50. Username: "test",
  51. UploadBandwidth: 50,
  52. DownloadBandwidth: 40,
  53. },
  54. }
  55. fs := vfs.NewOsFs("", os.TempDir(), "")
  56. testFileSize := int64(131072)
  57. wantedUploadElapsed := 1000 * (testFileSize / 1024) / u.UploadBandwidth
  58. wantedDownloadElapsed := 1000 * (testFileSize / 1024) / u.DownloadBandwidth
  59. // some tolerance
  60. wantedUploadElapsed -= wantedDownloadElapsed / 10
  61. wantedDownloadElapsed -= wantedDownloadElapsed / 10
  62. conn := NewBaseConnection("id", ProtocolSCP, "", "", u)
  63. transfer := NewBaseTransfer(nil, conn, nil, "", "", "", TransferUpload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{})
  64. transfer.BytesReceived = testFileSize
  65. transfer.Connection.UpdateLastActivity()
  66. startTime := transfer.Connection.GetLastActivity()
  67. transfer.HandleThrottle()
  68. elapsed := time.Since(startTime).Nanoseconds() / 1000000
  69. assert.GreaterOrEqual(t, elapsed, wantedUploadElapsed, "upload bandwidth throttling not respected")
  70. err := transfer.Close()
  71. assert.NoError(t, err)
  72. transfer = NewBaseTransfer(nil, conn, nil, "", "", "", TransferDownload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{})
  73. transfer.BytesSent = testFileSize
  74. transfer.Connection.UpdateLastActivity()
  75. startTime = transfer.Connection.GetLastActivity()
  76. transfer.HandleThrottle()
  77. elapsed = time.Since(startTime).Nanoseconds() / 1000000
  78. assert.GreaterOrEqual(t, elapsed, wantedDownloadElapsed, "download bandwidth throttling not respected")
  79. err = transfer.Close()
  80. assert.NoError(t, err)
  81. }
  82. func TestRealPath(t *testing.T) {
  83. testFile := filepath.Join(os.TempDir(), "afile.txt")
  84. fs := vfs.NewOsFs("123", os.TempDir(), "")
  85. u := dataprovider.User{
  86. BaseUser: sdk.BaseUser{
  87. Username: "user",
  88. HomeDir: os.TempDir(),
  89. },
  90. }
  91. u.Permissions = make(map[string][]string)
  92. u.Permissions["/"] = []string{dataprovider.PermAny}
  93. file, err := os.Create(testFile)
  94. require.NoError(t, err)
  95. conn := NewBaseConnection(fs.ConnectionID(), ProtocolSFTP, "", "", u)
  96. transfer := NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file",
  97. TransferUpload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{})
  98. rPath := transfer.GetRealFsPath(testFile)
  99. assert.Equal(t, testFile, rPath)
  100. rPath = conn.getRealFsPath(testFile)
  101. assert.Equal(t, testFile, rPath)
  102. err = transfer.Close()
  103. assert.NoError(t, err)
  104. err = file.Close()
  105. assert.NoError(t, err)
  106. transfer.File = nil
  107. rPath = transfer.GetRealFsPath(testFile)
  108. assert.Equal(t, testFile, rPath)
  109. rPath = transfer.GetRealFsPath("")
  110. assert.Empty(t, rPath)
  111. err = os.Remove(testFile)
  112. assert.NoError(t, err)
  113. assert.Len(t, conn.GetTransfers(), 0)
  114. }
  115. func TestTruncate(t *testing.T) {
  116. testFile := filepath.Join(os.TempDir(), "transfer_test_file")
  117. fs := vfs.NewOsFs("123", os.TempDir(), "")
  118. u := dataprovider.User{
  119. BaseUser: sdk.BaseUser{
  120. Username: "user",
  121. HomeDir: os.TempDir(),
  122. },
  123. }
  124. u.Permissions = make(map[string][]string)
  125. u.Permissions["/"] = []string{dataprovider.PermAny}
  126. file, err := os.Create(testFile)
  127. if !assert.NoError(t, err) {
  128. assert.FailNow(t, "unable to open test file")
  129. }
  130. _, err = file.Write([]byte("hello"))
  131. assert.NoError(t, err)
  132. conn := NewBaseConnection(fs.ConnectionID(), ProtocolSFTP, "", "", u)
  133. transfer := NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, 0, 5,
  134. 100, 0, false, fs, dataprovider.TransferQuota{})
  135. err = conn.SetStat("/transfer_test_file", &StatAttributes{
  136. Size: 2,
  137. Flags: StatAttrSize,
  138. })
  139. assert.NoError(t, err)
  140. assert.Equal(t, int64(103), transfer.MaxWriteSize)
  141. err = transfer.Close()
  142. assert.NoError(t, err)
  143. err = file.Close()
  144. assert.NoError(t, err)
  145. fi, err := os.Stat(testFile)
  146. if assert.NoError(t, err) {
  147. assert.Equal(t, int64(2), fi.Size())
  148. }
  149. transfer = NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, 0, 0,
  150. 100, 0, true, fs, dataprovider.TransferQuota{})
  151. // file.Stat will fail on a closed file
  152. err = conn.SetStat("/transfer_test_file", &StatAttributes{
  153. Size: 2,
  154. Flags: StatAttrSize,
  155. })
  156. assert.Error(t, err)
  157. err = transfer.Close()
  158. assert.NoError(t, err)
  159. transfer = NewBaseTransfer(nil, conn, nil, testFile, testFile, "", TransferUpload, 0, 0, 0, 0, true,
  160. fs, dataprovider.TransferQuota{})
  161. _, err = transfer.Truncate("mismatch", 0)
  162. assert.EqualError(t, err, errTransferMismatch.Error())
  163. _, err = transfer.Truncate(testFile, 0)
  164. assert.NoError(t, err)
  165. _, err = transfer.Truncate(testFile, 1)
  166. assert.EqualError(t, err, vfs.ErrVfsUnsupported.Error())
  167. err = transfer.Close()
  168. assert.NoError(t, err)
  169. err = os.Remove(testFile)
  170. assert.NoError(t, err)
  171. assert.Len(t, conn.GetTransfers(), 0)
  172. }
  173. func TestTransferErrors(t *testing.T) {
  174. isCancelled := false
  175. cancelFn := func() {
  176. isCancelled = true
  177. }
  178. testFile := filepath.Join(os.TempDir(), "transfer_test_file")
  179. fs := vfs.NewOsFs("id", os.TempDir(), "")
  180. u := dataprovider.User{
  181. BaseUser: sdk.BaseUser{
  182. Username: "test",
  183. HomeDir: os.TempDir(),
  184. },
  185. }
  186. err := os.WriteFile(testFile, []byte("test data"), os.ModePerm)
  187. assert.NoError(t, err)
  188. file, err := os.Open(testFile)
  189. if !assert.NoError(t, err) {
  190. assert.FailNow(t, "unable to open test file")
  191. }
  192. conn := NewBaseConnection("id", ProtocolSFTP, "", "", u)
  193. transfer := NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload,
  194. 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{})
  195. assert.Nil(t, transfer.cancelFn)
  196. assert.Equal(t, testFile, transfer.GetFsPath())
  197. transfer.SetCancelFn(cancelFn)
  198. errFake := errors.New("err fake")
  199. transfer.BytesReceived = 9
  200. transfer.TransferError(ErrQuotaExceeded)
  201. assert.True(t, isCancelled)
  202. transfer.TransferError(errFake)
  203. assert.Error(t, transfer.ErrTransfer, ErrQuotaExceeded.Error())
  204. // the file is closed from the embedding struct before to call close
  205. err = file.Close()
  206. assert.NoError(t, err)
  207. err = transfer.Close()
  208. if assert.Error(t, err) {
  209. assert.Error(t, err, ErrQuotaExceeded.Error())
  210. }
  211. assert.NoFileExists(t, testFile)
  212. err = os.WriteFile(testFile, []byte("test data"), os.ModePerm)
  213. assert.NoError(t, err)
  214. file, err = os.Open(testFile)
  215. if !assert.NoError(t, err) {
  216. assert.FailNow(t, "unable to open test file")
  217. }
  218. fsPath := filepath.Join(os.TempDir(), "test_file")
  219. transfer = NewBaseTransfer(file, conn, nil, fsPath, file.Name(), "/test_file", TransferUpload, 0, 0, 0, 0, true,
  220. fs, dataprovider.TransferQuota{})
  221. transfer.BytesReceived = 9
  222. transfer.TransferError(errFake)
  223. assert.Error(t, transfer.ErrTransfer, errFake.Error())
  224. // the file is closed from the embedding struct before to call close
  225. err = file.Close()
  226. assert.NoError(t, err)
  227. err = transfer.Close()
  228. if assert.Error(t, err) {
  229. assert.Error(t, err, errFake.Error())
  230. }
  231. assert.NoFileExists(t, testFile)
  232. err = os.WriteFile(testFile, []byte("test data"), os.ModePerm)
  233. assert.NoError(t, err)
  234. file, err = os.Open(testFile)
  235. if !assert.NoError(t, err) {
  236. assert.FailNow(t, "unable to open test file")
  237. }
  238. transfer = NewBaseTransfer(file, conn, nil, fsPath, file.Name(), "/test_file", TransferUpload, 0, 0, 0, 0, true,
  239. fs, dataprovider.TransferQuota{})
  240. transfer.BytesReceived = 9
  241. // the file is closed from the embedding struct before to call close
  242. err = file.Close()
  243. assert.NoError(t, err)
  244. err = transfer.Close()
  245. assert.NoError(t, err)
  246. assert.NoFileExists(t, testFile)
  247. assert.FileExists(t, fsPath)
  248. err = os.Remove(fsPath)
  249. assert.NoError(t, err)
  250. assert.Len(t, conn.GetTransfers(), 0)
  251. }
  252. func TestRemovePartialCryptoFile(t *testing.T) {
  253. testFile := filepath.Join(os.TempDir(), "transfer_test_file")
  254. fs, err := vfs.NewCryptFs("id", os.TempDir(), "", vfs.CryptFsConfig{Passphrase: kms.NewPlainSecret("secret")})
  255. require.NoError(t, err)
  256. u := dataprovider.User{
  257. BaseUser: sdk.BaseUser{
  258. Username: "test",
  259. HomeDir: os.TempDir(),
  260. },
  261. }
  262. conn := NewBaseConnection(fs.ConnectionID(), ProtocolSFTP, "", "", u)
  263. transfer := NewBaseTransfer(nil, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload,
  264. 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{})
  265. transfer.ErrTransfer = errors.New("test error")
  266. _, err = transfer.getUploadFileSize()
  267. assert.Error(t, err)
  268. err = os.WriteFile(testFile, []byte("test data"), os.ModePerm)
  269. assert.NoError(t, err)
  270. size, err := transfer.getUploadFileSize()
  271. assert.NoError(t, err)
  272. assert.Equal(t, int64(9), size)
  273. assert.NoFileExists(t, testFile)
  274. }
  275. func TestFTPMode(t *testing.T) {
  276. conn := NewBaseConnection("", ProtocolFTP, "", "", dataprovider.User{})
  277. transfer := BaseTransfer{
  278. Connection: conn,
  279. transferType: TransferUpload,
  280. BytesReceived: 123,
  281. Fs: vfs.NewOsFs("", os.TempDir(), ""),
  282. }
  283. assert.Empty(t, transfer.ftpMode)
  284. transfer.SetFtpMode("active")
  285. assert.Equal(t, "active", transfer.ftpMode)
  286. }
  287. func TestTransferQuota(t *testing.T) {
  288. user := dataprovider.User{
  289. BaseUser: sdk.BaseUser{
  290. TotalDataTransfer: -1,
  291. UploadDataTransfer: -1,
  292. DownloadDataTransfer: -1,
  293. },
  294. }
  295. user.Filters.DataTransferLimits = []sdk.DataTransferLimit{
  296. {
  297. Sources: []string{"127.0.0.1/32", "192.168.1.0/24"},
  298. TotalDataTransfer: 100,
  299. UploadDataTransfer: 0,
  300. DownloadDataTransfer: 0,
  301. },
  302. {
  303. Sources: []string{"172.16.0.0/24"},
  304. TotalDataTransfer: 0,
  305. UploadDataTransfer: 120,
  306. DownloadDataTransfer: 150,
  307. },
  308. }
  309. ul, dl, total := user.GetDataTransferLimits("127.0.1.1")
  310. assert.Equal(t, int64(0), ul)
  311. assert.Equal(t, int64(0), dl)
  312. assert.Equal(t, int64(0), total)
  313. ul, dl, total = user.GetDataTransferLimits("127.0.0.1")
  314. assert.Equal(t, int64(0), ul)
  315. assert.Equal(t, int64(0), dl)
  316. assert.Equal(t, int64(100*1048576), total)
  317. ul, dl, total = user.GetDataTransferLimits("192.168.1.4")
  318. assert.Equal(t, int64(0), ul)
  319. assert.Equal(t, int64(0), dl)
  320. assert.Equal(t, int64(100*1048576), total)
  321. ul, dl, total = user.GetDataTransferLimits("172.16.0.2")
  322. assert.Equal(t, int64(120*1048576), ul)
  323. assert.Equal(t, int64(150*1048576), dl)
  324. assert.Equal(t, int64(0), total)
  325. transferQuota := dataprovider.TransferQuota{}
  326. assert.True(t, transferQuota.HasDownloadSpace())
  327. assert.True(t, transferQuota.HasUploadSpace())
  328. transferQuota.TotalSize = -1
  329. transferQuota.ULSize = -1
  330. transferQuota.DLSize = -1
  331. assert.True(t, transferQuota.HasDownloadSpace())
  332. assert.True(t, transferQuota.HasUploadSpace())
  333. transferQuota.TotalSize = 100
  334. transferQuota.AllowedTotalSize = 10
  335. assert.True(t, transferQuota.HasDownloadSpace())
  336. assert.True(t, transferQuota.HasUploadSpace())
  337. transferQuota.AllowedTotalSize = 0
  338. assert.False(t, transferQuota.HasDownloadSpace())
  339. assert.False(t, transferQuota.HasUploadSpace())
  340. transferQuota.TotalSize = 0
  341. transferQuota.DLSize = 100
  342. transferQuota.ULSize = 50
  343. transferQuota.AllowedTotalSize = 0
  344. assert.False(t, transferQuota.HasDownloadSpace())
  345. assert.False(t, transferQuota.HasUploadSpace())
  346. transferQuota.AllowedDLSize = 1
  347. transferQuota.AllowedULSize = 1
  348. assert.True(t, transferQuota.HasDownloadSpace())
  349. assert.True(t, transferQuota.HasUploadSpace())
  350. transferQuota.AllowedDLSize = -10
  351. transferQuota.AllowedULSize = -1
  352. assert.False(t, transferQuota.HasDownloadSpace())
  353. assert.False(t, transferQuota.HasUploadSpace())
  354. conn := NewBaseConnection("", ProtocolSFTP, "", "", user)
  355. transfer := NewBaseTransfer(nil, conn, nil, "file.txt", "file.txt", "/transfer_test_file", TransferUpload,
  356. 0, 0, 0, 0, true, vfs.NewOsFs("", os.TempDir(), ""), dataprovider.TransferQuota{})
  357. err := transfer.CheckRead()
  358. assert.NoError(t, err)
  359. err = transfer.CheckWrite()
  360. assert.NoError(t, err)
  361. transfer.transferQuota = dataprovider.TransferQuota{
  362. AllowedTotalSize: 10,
  363. }
  364. transfer.BytesReceived = 5
  365. transfer.BytesSent = 4
  366. err = transfer.CheckRead()
  367. assert.NoError(t, err)
  368. err = transfer.CheckWrite()
  369. assert.NoError(t, err)
  370. transfer.BytesSent = 6
  371. err = transfer.CheckRead()
  372. if assert.Error(t, err) {
  373. assert.Contains(t, err.Error(), ErrReadQuotaExceeded.Error())
  374. }
  375. err = transfer.CheckWrite()
  376. assert.True(t, conn.IsQuotaExceededError(err))
  377. transferQuota = dataprovider.TransferQuota{
  378. AllowedTotalSize: 0,
  379. AllowedULSize: 10,
  380. AllowedDLSize: 5,
  381. }
  382. transfer.transferQuota = transferQuota
  383. assert.Equal(t, transferQuota, transfer.GetTransferQuota())
  384. err = transfer.CheckRead()
  385. if assert.Error(t, err) {
  386. assert.Contains(t, err.Error(), ErrReadQuotaExceeded.Error())
  387. }
  388. err = transfer.CheckWrite()
  389. assert.NoError(t, err)
  390. transfer.BytesReceived = 11
  391. err = transfer.CheckRead()
  392. if assert.Error(t, err) {
  393. assert.Contains(t, err.Error(), ErrReadQuotaExceeded.Error())
  394. }
  395. err = transfer.CheckWrite()
  396. assert.True(t, conn.IsQuotaExceededError(err))
  397. }
  398. func TestUploadOutsideHomeRenameError(t *testing.T) {
  399. oldTempPath := Config.TempPath
  400. conn := NewBaseConnection("", ProtocolSFTP, "", "", dataprovider.User{})
  401. transfer := BaseTransfer{
  402. Connection: conn,
  403. transferType: TransferUpload,
  404. BytesReceived: 123,
  405. Fs: vfs.NewOsFs("", filepath.Join(os.TempDir(), "home"), ""),
  406. }
  407. fileName := filepath.Join(os.TempDir(), "_temp")
  408. err := os.WriteFile(fileName, []byte(`data`), 0644)
  409. assert.NoError(t, err)
  410. transfer.effectiveFsPath = fileName
  411. res := transfer.checkUploadOutsideHomeDir(os.ErrPermission)
  412. assert.Equal(t, 0, res)
  413. Config.TempPath = filepath.Clean(os.TempDir())
  414. res = transfer.checkUploadOutsideHomeDir(nil)
  415. assert.Equal(t, 0, res)
  416. assert.Greater(t, transfer.BytesReceived, int64(0))
  417. res = transfer.checkUploadOutsideHomeDir(os.ErrPermission)
  418. assert.Equal(t, 1, res)
  419. assert.Equal(t, int64(0), transfer.BytesReceived)
  420. assert.NoFileExists(t, fileName)
  421. Config.TempPath = oldTempPath
  422. }