internal_test.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454
  1. package ftpd
  2. import (
  3. "fmt"
  4. "io/ioutil"
  5. "net"
  6. "os"
  7. "path/filepath"
  8. "runtime"
  9. "testing"
  10. "time"
  11. "github.com/eikenb/pipeat"
  12. "github.com/stretchr/testify/assert"
  13. "github.com/drakkan/sftpgo/common"
  14. "github.com/drakkan/sftpgo/dataprovider"
  15. "github.com/drakkan/sftpgo/vfs"
  16. )
  // configDir is the configuration directory handed to Initialize/NewServer
// by the tests below; it points at the parent of this package's directory.
const configDir = ".."
  // mockFTPClientContext is a minimal, stateless stand-in for the FTP server's
// client context used by the tests in this file. Every method returns a fixed
// value and has no side effects.
//
// Both addresses are the IPv4 loopback 127.0.0.1, built with net.IPv4. The
// earlier []byte("127.0.0.1") produced the nine ASCII bytes of the dotted
// string rather than a 4- or 16-byte net.IP. The net package treats that as
// an invalid address and renders it in a hexadecimal fallback form instead
// of "127.0.0.1".
type mockFTPClientContext struct{}

// Path always returns the empty string.
func (cc mockFTPClientContext) Path() string {
	return ""
}

// SetDebug is a no-op; Debug keeps reporting false.
func (cc mockFTPClientContext) SetDebug(debug bool) {}

// Debug always reports false.
func (cc mockFTPClientContext) Debug() bool {
	return false
}

// ID always returns 1; tests derive their connection IDs from it.
func (cc mockFTPClientContext) ID() uint32 {
	return 1
}

// RemoteAddr returns a new 127.0.0.1 address on every call.
func (cc mockFTPClientContext) RemoteAddr() net.Addr {
	return &net.IPAddr{IP: net.IPv4(127, 0, 0, 1)}
}

// LocalAddr returns a new 127.0.0.1 address on every call.
func (cc mockFTPClientContext) LocalAddr() net.Addr {
	return &net.IPAddr{IP: net.IPv4(127, 0, 0, 1)}
}

// GetClientVersion returns the fixed string "mock version".
func (cc mockFTPClientContext) GetClientVersion() string {
	return "mock version"
}

// Close ignores its arguments and always succeeds.
func (cc mockFTPClientContext) Close(code int, message string) error {
	return nil
}
  44. // MockOsFs mockable OsFs
  45. type MockOsFs struct {
  46. vfs.Fs
  47. err error
  48. statErr error
  49. isAtomicUploadSupported bool
  50. }
  51. // Name returns the name for the Fs implementation
  52. func (fs MockOsFs) Name() string {
  53. return "mockOsFs"
  54. }
  55. // IsUploadResumeSupported returns true if upload resume is supported
  56. func (MockOsFs) IsUploadResumeSupported() bool {
  57. return false
  58. }
  59. // IsAtomicUploadSupported returns true if atomic upload is supported
  60. func (fs MockOsFs) IsAtomicUploadSupported() bool {
  61. return fs.isAtomicUploadSupported
  62. }
  63. // Stat returns a FileInfo describing the named file
  64. func (fs MockOsFs) Stat(name string) (os.FileInfo, error) {
  65. if fs.statErr != nil {
  66. return nil, fs.statErr
  67. }
  68. return os.Stat(name)
  69. }
  70. // Lstat returns a FileInfo describing the named file
  71. func (fs MockOsFs) Lstat(name string) (os.FileInfo, error) {
  72. if fs.statErr != nil {
  73. return nil, fs.statErr
  74. }
  75. return os.Lstat(name)
  76. }
  77. // Remove removes the named file or (empty) directory.
  78. func (fs MockOsFs) Remove(name string, isDir bool) error {
  79. if fs.err != nil {
  80. return fs.err
  81. }
  82. return os.Remove(name)
  83. }
  84. // Rename renames (moves) source to target
  85. func (fs MockOsFs) Rename(source, target string) error {
  86. if fs.err != nil {
  87. return fs.err
  88. }
  89. return os.Rename(source, target)
  90. }
  91. func newMockOsFs(err, statErr error, atomicUpload bool, connectionID, rootDir string) vfs.Fs {
  92. return &MockOsFs{
  93. Fs: vfs.NewOsFs(connectionID, rootDir, nil),
  94. err: err,
  95. statErr: statErr,
  96. isAtomicUploadSupported: atomicUpload,
  97. }
  98. }
  99. func TestInitialization(t *testing.T) {
  100. c := &Configuration{
  101. BindPort: 2121,
  102. CertificateFile: "acert",
  103. CertificateKeyFile: "akey",
  104. }
  105. err := c.Initialize(configDir)
  106. assert.Error(t, err)
  107. c.CertificateFile = ""
  108. c.CertificateKeyFile = ""
  109. c.BannerFile = "afile"
  110. server, err := NewServer(c, configDir)
  111. if assert.NoError(t, err) {
  112. assert.Equal(t, "", server.initialMsg)
  113. _, err = server.GetTLSConfig()
  114. assert.Error(t, err)
  115. }
  116. err = ReloadTLSCertificate()
  117. assert.NoError(t, err)
  118. }
  119. func TestServerGetSettings(t *testing.T) {
  120. oldConfig := common.Config
  121. c := &Configuration{
  122. BindPort: 2121,
  123. PassivePortRange: PortRange{
  124. Start: 10000,
  125. End: 11000,
  126. },
  127. }
  128. server, err := NewServer(c, configDir)
  129. assert.NoError(t, err)
  130. settings, err := server.GetSettings()
  131. assert.NoError(t, err)
  132. assert.Equal(t, 10000, settings.PassiveTransferPortRange.Start)
  133. assert.Equal(t, 11000, settings.PassiveTransferPortRange.End)
  134. common.Config.ProxyProtocol = 1
  135. common.Config.ProxyAllowed = []string{"invalid"}
  136. _, err = server.GetSettings()
  137. assert.Error(t, err)
  138. server.config.BindPort = 8021
  139. _, err = server.GetSettings()
  140. assert.Error(t, err)
  141. common.Config = oldConfig
  142. }
  143. func TestUserInvalidParams(t *testing.T) {
  144. u := dataprovider.User{
  145. HomeDir: "invalid",
  146. }
  147. c := &Configuration{
  148. BindPort: 2121,
  149. PassivePortRange: PortRange{
  150. Start: 10000,
  151. End: 11000,
  152. },
  153. }
  154. server, err := NewServer(c, configDir)
  155. assert.NoError(t, err)
  156. _, err = server.validateUser(u, mockFTPClientContext{})
  157. assert.Error(t, err)
  158. u.Username = "a"
  159. u.HomeDir = filepath.Clean(os.TempDir())
  160. subDir := "subdir"
  161. mappedPath1 := filepath.Join(os.TempDir(), "vdir1")
  162. vdirPath1 := "/vdir1"
  163. mappedPath2 := filepath.Join(os.TempDir(), "vdir1", subDir)
  164. vdirPath2 := "/vdir2"
  165. u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
  166. BaseVirtualFolder: vfs.BaseVirtualFolder{
  167. MappedPath: mappedPath1,
  168. },
  169. VirtualPath: vdirPath1,
  170. })
  171. u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
  172. BaseVirtualFolder: vfs.BaseVirtualFolder{
  173. MappedPath: mappedPath2,
  174. },
  175. VirtualPath: vdirPath2,
  176. })
  177. _, err = server.validateUser(u, mockFTPClientContext{})
  178. assert.Error(t, err)
  179. u.VirtualFolders = nil
  180. _, err = server.validateUser(u, mockFTPClientContext{})
  181. assert.Error(t, err)
  182. }
  183. func TestClientVersion(t *testing.T) {
  184. mockCC := mockFTPClientContext{}
  185. connID := fmt.Sprintf("%v", mockCC.ID())
  186. user := dataprovider.User{}
  187. connection := &Connection{
  188. BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, user, nil),
  189. clientContext: mockCC,
  190. }
  191. common.Connections.Add(connection)
  192. stats := common.Connections.GetStats()
  193. if assert.Len(t, stats, 1) {
  194. assert.Equal(t, "mock version", stats[0].ClientVersion)
  195. common.Connections.Remove(connection.GetID())
  196. }
  197. assert.Len(t, common.Connections.GetStats(), 0)
  198. }
  199. func TestDriverMethodsNotImplemented(t *testing.T) {
  200. mockCC := mockFTPClientContext{}
  201. connID := fmt.Sprintf("%v", mockCC.ID())
  202. user := dataprovider.User{}
  203. connection := &Connection{
  204. BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, user, nil),
  205. clientContext: mockCC,
  206. }
  207. _, err := connection.Create("")
  208. assert.EqualError(t, err, errNotImplemented.Error())
  209. err = connection.MkdirAll("", os.ModePerm)
  210. assert.EqualError(t, err, errNotImplemented.Error())
  211. _, err = connection.Open("")
  212. assert.EqualError(t, err, errNotImplemented.Error())
  213. _, err = connection.OpenFile("", 0, os.ModePerm)
  214. assert.EqualError(t, err, errNotImplemented.Error())
  215. err = connection.RemoveAll("")
  216. assert.EqualError(t, err, errNotImplemented.Error())
  217. assert.Equal(t, connection.GetID(), connection.Name())
  218. }
  219. func TestResolvePathErrors(t *testing.T) {
  220. user := dataprovider.User{
  221. HomeDir: "invalid",
  222. }
  223. user.Permissions = make(map[string][]string)
  224. user.Permissions["/"] = []string{dataprovider.PermAny}
  225. mockCC := mockFTPClientContext{}
  226. connID := fmt.Sprintf("%v", mockCC.ID())
  227. fs := vfs.NewOsFs(connID, user.HomeDir, nil)
  228. connection := &Connection{
  229. BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, user, fs),
  230. clientContext: mockCC,
  231. }
  232. err := connection.Mkdir("", os.ModePerm)
  233. if assert.Error(t, err) {
  234. assert.EqualError(t, err, common.ErrGenericFailure.Error())
  235. }
  236. err = connection.Remove("")
  237. if assert.Error(t, err) {
  238. assert.EqualError(t, err, common.ErrGenericFailure.Error())
  239. }
  240. err = connection.RemoveDir("")
  241. if assert.Error(t, err) {
  242. assert.EqualError(t, err, common.ErrGenericFailure.Error())
  243. }
  244. err = connection.Rename("", "")
  245. if assert.Error(t, err) {
  246. assert.EqualError(t, err, common.ErrGenericFailure.Error())
  247. }
  248. err = connection.Symlink("", "")
  249. if assert.Error(t, err) {
  250. assert.EqualError(t, err, common.ErrGenericFailure.Error())
  251. }
  252. _, err = connection.Stat("")
  253. if assert.Error(t, err) {
  254. assert.EqualError(t, err, common.ErrGenericFailure.Error())
  255. }
  256. err = connection.Chmod("", os.ModePerm)
  257. if assert.Error(t, err) {
  258. assert.EqualError(t, err, common.ErrGenericFailure.Error())
  259. }
  260. err = connection.Chtimes("", time.Now(), time.Now())
  261. if assert.Error(t, err) {
  262. assert.EqualError(t, err, common.ErrGenericFailure.Error())
  263. }
  264. _, err = connection.ReadDir("")
  265. if assert.Error(t, err) {
  266. assert.EqualError(t, err, common.ErrGenericFailure.Error())
  267. }
  268. _, err = connection.GetHandle("", 0, 0)
  269. if assert.Error(t, err) {
  270. assert.EqualError(t, err, common.ErrGenericFailure.Error())
  271. }
  272. }
  273. func TestUploadFileStatError(t *testing.T) {
  274. if runtime.GOOS == "windows" {
  275. t.Skip("this test is not available on Windows")
  276. }
  277. user := dataprovider.User{
  278. Username: "user",
  279. HomeDir: filepath.Clean(os.TempDir()),
  280. }
  281. user.Permissions = make(map[string][]string)
  282. user.Permissions["/"] = []string{dataprovider.PermAny}
  283. mockCC := mockFTPClientContext{}
  284. connID := fmt.Sprintf("%v", mockCC.ID())
  285. fs := vfs.NewOsFs(connID, user.HomeDir, nil)
  286. connection := &Connection{
  287. BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, user, fs),
  288. clientContext: mockCC,
  289. }
  290. testFile := filepath.Join(user.HomeDir, "test", "testfile")
  291. err := os.MkdirAll(filepath.Dir(testFile), os.ModePerm)
  292. assert.NoError(t, err)
  293. err = ioutil.WriteFile(testFile, []byte("data"), os.ModePerm)
  294. assert.NoError(t, err)
  295. err = os.Chmod(filepath.Dir(testFile), 0001)
  296. assert.NoError(t, err)
  297. _, err = connection.uploadFile(testFile, "test", 0)
  298. assert.Error(t, err)
  299. err = os.Chmod(filepath.Dir(testFile), os.ModePerm)
  300. assert.NoError(t, err)
  301. err = os.RemoveAll(filepath.Dir(testFile))
  302. assert.NoError(t, err)
  303. }
  304. func TestUploadOverwriteErrors(t *testing.T) {
  305. user := dataprovider.User{
  306. Username: "user",
  307. HomeDir: filepath.Clean(os.TempDir()),
  308. }
  309. user.Permissions = make(map[string][]string)
  310. user.Permissions["/"] = []string{dataprovider.PermAny}
  311. mockCC := mockFTPClientContext{}
  312. connID := fmt.Sprintf("%v", mockCC.ID())
  313. fs := newMockOsFs(nil, nil, false, connID, user.GetHomeDir())
  314. connection := &Connection{
  315. BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, user, fs),
  316. clientContext: mockCC,
  317. }
  318. flags := 0
  319. flags |= os.O_APPEND
  320. _, err := connection.handleFTPUploadToExistingFile(flags, "", "", 0, "")
  321. if assert.Error(t, err) {
  322. assert.EqualError(t, err, common.ErrOpUnsupported.Error())
  323. }
  324. f, err := ioutil.TempFile("", "temp")
  325. assert.NoError(t, err)
  326. err = f.Close()
  327. assert.NoError(t, err)
  328. flags = 0
  329. flags |= os.O_CREATE
  330. flags |= os.O_TRUNC
  331. tr, err := connection.handleFTPUploadToExistingFile(flags, f.Name(), f.Name(), 123, f.Name())
  332. if assert.NoError(t, err) {
  333. transfer := tr.(*transfer)
  334. transfers := connection.GetTransfers()
  335. if assert.Equal(t, 1, len(transfers)) {
  336. assert.Equal(t, transfers[0].ID, transfer.GetID())
  337. assert.Equal(t, int64(123), transfer.InitialSize)
  338. err = transfer.Close()
  339. assert.NoError(t, err)
  340. assert.Equal(t, 0, len(connection.GetTransfers()))
  341. }
  342. }
  343. err = os.Remove(f.Name())
  344. assert.NoError(t, err)
  345. _, err = connection.handleFTPUploadToExistingFile(0, filepath.Join(os.TempDir(), "sub", "file"),
  346. filepath.Join(os.TempDir(), "sub", "file1"), 0, "/sub/file1")
  347. assert.Error(t, err)
  348. connection.Fs = vfs.NewOsFs(connID, user.GetHomeDir(), nil)
  349. _, err = connection.handleFTPUploadToExistingFile(0, "missing1", "missing2", 0, "missing")
  350. assert.Error(t, err)
  351. }
  352. func TestTransferErrors(t *testing.T) {
  353. testfile := "testfile"
  354. file, err := os.Create(testfile)
  355. assert.NoError(t, err)
  356. user := dataprovider.User{
  357. Username: "user",
  358. HomeDir: filepath.Clean(os.TempDir()),
  359. }
  360. user.Permissions = make(map[string][]string)
  361. user.Permissions["/"] = []string{dataprovider.PermAny}
  362. mockCC := mockFTPClientContext{}
  363. connID := fmt.Sprintf("%v", mockCC.ID())
  364. fs := newMockOsFs(nil, nil, false, connID, user.GetHomeDir())
  365. connection := &Connection{
  366. BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, user, fs),
  367. clientContext: mockCC,
  368. }
  369. baseTransfer := common.NewBaseTransfer(file, connection.BaseConnection, nil, file.Name(), testfile, common.TransferDownload,
  370. 0, 0, 0, false, fs)
  371. tr := newTransfer(baseTransfer, nil, nil, 0)
  372. err = tr.Close()
  373. assert.NoError(t, err)
  374. _, err = tr.Seek(10, 0)
  375. assert.Error(t, err)
  376. buf := make([]byte, 64)
  377. _, err = tr.Read(buf)
  378. assert.Error(t, err)
  379. err = tr.Close()
  380. if assert.Error(t, err) {
  381. assert.EqualError(t, err, common.ErrTransferClosed.Error())
  382. }
  383. assert.Len(t, connection.GetTransfers(), 0)
  384. r, _, err := pipeat.Pipe()
  385. assert.NoError(t, err)
  386. baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testfile, testfile,
  387. common.TransferUpload, 0, 0, 0, false, fs)
  388. tr = newTransfer(baseTransfer, nil, r, 10)
  389. pos, err := tr.Seek(10, 0)
  390. assert.NoError(t, err)
  391. assert.Equal(t, pos, tr.expectedOffset)
  392. err = tr.closeIO()
  393. assert.NoError(t, err)
  394. r, w, err := pipeat.Pipe()
  395. assert.NoError(t, err)
  396. pipeWriter := vfs.NewPipeWriter(w)
  397. baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testfile, testfile,
  398. common.TransferUpload, 0, 0, 0, false, fs)
  399. tr = newTransfer(baseTransfer, pipeWriter, nil, 0)
  400. err = r.Close()
  401. assert.NoError(t, err)
  402. errFake := fmt.Errorf("fake upload error")
  403. go func() {
  404. time.Sleep(100 * time.Millisecond)
  405. pipeWriter.Done(errFake)
  406. }()
  407. err = tr.closeIO()
  408. assert.EqualError(t, err, errFake.Error())
  409. _, err = tr.Seek(1, 0)
  410. if assert.Error(t, err) {
  411. assert.EqualError(t, err, common.ErrOpUnsupported.Error())
  412. }
  413. err = os.Remove(testfile)
  414. assert.NoError(t, err)
  415. }