common_test.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536
  1. package common
  2. import (
  3. "fmt"
  4. "net"
  5. "net/http"
  6. "os"
  7. "os/exec"
  8. "runtime"
  9. "strings"
  10. "sync/atomic"
  11. "testing"
  12. "time"
  13. "github.com/rs/zerolog"
  14. "github.com/spf13/viper"
  15. "github.com/stretchr/testify/assert"
  16. "github.com/drakkan/sftpgo/dataprovider"
  17. "github.com/drakkan/sftpgo/httpclient"
  18. "github.com/drakkan/sftpgo/logger"
  19. "github.com/drakkan/sftpgo/vfs"
  20. )
  // Shared fixtures for the tests in this package.
const (
	// logSenderTest names this test suite; it is not referenced in the
	// visible part of the file (presumably used as a log sender elsewhere).
	logSenderTest = "common_test"

	// Listen addresses of the helper servers started by TestMain: a plain
	// HTTP server used as hook/notification target, and an HTTP server
	// behind a PROXY protocol listener.
	httpAddr      = "127.0.0.1:9999"
	httpProxyAddr = "127.0.0.1:7777"

	// configDir is the directory searched for the sftpgo configuration file.
	configDir = ".."

	// osWindows is compared with runtime.GOOS for platform-specific checks.
	osWindows = "windows"

	// Test user identifiers; userTestPwd is not used in the visible part of
	// this file.
	userTestUsername = "common_test_username"
	userTestPwd      = "common_test_pwd"
)
  30. type providerConf struct {
  31. Config dataprovider.Config `json:"data_provider" mapstructure:"data_provider"`
  32. }
  33. type fakeConnection struct {
  34. *BaseConnection
  35. command string
  36. }
  37. func (c *fakeConnection) AddUser(user dataprovider.User) error {
  38. fs, err := user.GetFilesystem(c.GetID())
  39. if err != nil {
  40. return err
  41. }
  42. c.BaseConnection.User = user
  43. c.BaseConnection.Fs = fs
  44. return nil
  45. }
  46. func (c *fakeConnection) Disconnect() error {
  47. Connections.Remove(c.GetID())
  48. return nil
  49. }
  50. func (c *fakeConnection) GetClientVersion() string {
  51. return ""
  52. }
  53. func (c *fakeConnection) GetCommand() string {
  54. return c.command
  55. }
  56. func (c *fakeConnection) GetRemoteAddress() string {
  57. return ""
  58. }
  59. type customNetConn struct {
  60. net.Conn
  61. id string
  62. isClosed bool
  63. }
  64. func (c *customNetConn) Close() error {
  65. Connections.RemoveSSHConnection(c.id)
  66. c.isClosed = true
  67. return c.Conn.Close()
  68. }
  69. func TestMain(m *testing.M) {
  70. logfilePath := "common_test.log"
  71. logger.InitLogger(logfilePath, 5, 1, 28, false, zerolog.DebugLevel)
  72. viper.SetEnvPrefix("sftpgo")
  73. replacer := strings.NewReplacer(".", "__")
  74. viper.SetEnvKeyReplacer(replacer)
  75. viper.SetConfigName("sftpgo")
  76. viper.AutomaticEnv()
  77. viper.AllowEmptyEnv(true)
  78. driver, err := initializeDataprovider(-1)
  79. if err != nil {
  80. logger.WarnToConsole("error initializing data provider: %v", err)
  81. os.Exit(1)
  82. }
  83. logger.InfoToConsole("Starting COMMON tests, provider: %v", driver)
  84. Initialize(Configuration{})
  85. httpConfig := httpclient.Config{
  86. Timeout: 5,
  87. }
  88. httpConfig.Initialize(configDir)
  89. go func() {
  90. // start a test HTTP server to receive action notifications
  91. http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
  92. fmt.Fprintf(w, "OK\n")
  93. })
  94. http.HandleFunc("/404", func(w http.ResponseWriter, r *http.Request) {
  95. w.WriteHeader(http.StatusNotFound)
  96. fmt.Fprintf(w, "Not found\n")
  97. })
  98. if err := http.ListenAndServe(httpAddr, nil); err != nil {
  99. logger.ErrorToConsole("could not start HTTP notification server: %v", err)
  100. os.Exit(1)
  101. }
  102. }()
  103. go func() {
  104. Config.ProxyProtocol = 2
  105. listener, err := net.Listen("tcp", httpProxyAddr)
  106. if err != nil {
  107. logger.ErrorToConsole("error creating listener for proxy protocol server: %v", err)
  108. os.Exit(1)
  109. }
  110. proxyListener, err := Config.GetProxyListener(listener)
  111. if err != nil {
  112. logger.ErrorToConsole("error creating proxy protocol listener: %v", err)
  113. os.Exit(1)
  114. }
  115. Config.ProxyProtocol = 0
  116. s := &http.Server{}
  117. if err := s.Serve(proxyListener); err != nil {
  118. logger.ErrorToConsole("could not start HTTP proxy protocol server: %v", err)
  119. os.Exit(1)
  120. }
  121. }()
  122. waitTCPListening(httpAddr)
  123. waitTCPListening(httpProxyAddr)
  124. exitCode := m.Run()
  125. os.Remove(logfilePath) //nolint:errcheck
  126. os.Exit(exitCode)
  127. }
  128. func waitTCPListening(address string) {
  129. for {
  130. conn, err := net.Dial("tcp", address)
  131. if err != nil {
  132. logger.WarnToConsole("tcp server %v not listening: %v\n", address, err)
  133. time.Sleep(100 * time.Millisecond)
  134. continue
  135. }
  136. logger.InfoToConsole("tcp server %v now listening\n", address)
  137. conn.Close()
  138. break
  139. }
  140. }
  141. func initializeDataprovider(trackQuota int) (string, error) {
  142. configDir := ".."
  143. viper.AddConfigPath(configDir)
  144. if err := viper.ReadInConfig(); err != nil {
  145. return "", err
  146. }
  147. var cfg providerConf
  148. if err := viper.Unmarshal(&cfg); err != nil {
  149. return "", err
  150. }
  151. if trackQuota >= 0 && trackQuota <= 2 {
  152. cfg.Config.TrackQuota = trackQuota
  153. }
  154. return cfg.Config.Driver, dataprovider.Initialize(cfg.Config, configDir)
  155. }
  156. func closeDataprovider() error {
  157. return dataprovider.Close()
  158. }
  159. func TestSSHConnections(t *testing.T) {
  160. conn1, conn2 := net.Pipe()
  161. now := time.Now()
  162. sshConn1 := NewSSHConnection("id1", conn1)
  163. sshConn2 := NewSSHConnection("id2", conn2)
  164. sshConn3 := NewSSHConnection("id3", conn2)
  165. assert.Equal(t, "id1", sshConn1.GetID())
  166. assert.Equal(t, "id2", sshConn2.GetID())
  167. assert.Equal(t, "id3", sshConn3.GetID())
  168. sshConn1.UpdateLastActivity()
  169. assert.GreaterOrEqual(t, sshConn1.GetLastActivity().UnixNano(), now.UnixNano())
  170. Connections.AddSSHConnection(sshConn1)
  171. Connections.AddSSHConnection(sshConn2)
  172. Connections.AddSSHConnection(sshConn3)
  173. Connections.RLock()
  174. assert.Len(t, Connections.sshConnections, 3)
  175. Connections.RUnlock()
  176. Connections.RemoveSSHConnection(sshConn1.id)
  177. Connections.RLock()
  178. assert.Len(t, Connections.sshConnections, 2)
  179. assert.Equal(t, sshConn3.id, Connections.sshConnections[0].id)
  180. assert.Equal(t, sshConn2.id, Connections.sshConnections[1].id)
  181. Connections.RUnlock()
  182. Connections.RemoveSSHConnection(sshConn1.id)
  183. Connections.RLock()
  184. assert.Len(t, Connections.sshConnections, 2)
  185. assert.Equal(t, sshConn3.id, Connections.sshConnections[0].id)
  186. assert.Equal(t, sshConn2.id, Connections.sshConnections[1].id)
  187. Connections.RUnlock()
  188. Connections.RemoveSSHConnection(sshConn2.id)
  189. Connections.RLock()
  190. assert.Len(t, Connections.sshConnections, 1)
  191. assert.Equal(t, sshConn3.id, Connections.sshConnections[0].id)
  192. Connections.RUnlock()
  193. Connections.RemoveSSHConnection(sshConn3.id)
  194. Connections.RLock()
  195. assert.Len(t, Connections.sshConnections, 0)
  196. Connections.RUnlock()
  197. assert.NoError(t, sshConn1.Close())
  198. assert.NoError(t, sshConn2.Close())
  199. assert.NoError(t, sshConn3.Close())
  200. }
  201. func TestIdleConnections(t *testing.T) {
  202. configCopy := Config
  203. Config.IdleTimeout = 1
  204. Initialize(Config)
  205. conn1, conn2 := net.Pipe()
  206. customConn1 := &customNetConn{
  207. Conn: conn1,
  208. id: "id1",
  209. }
  210. customConn2 := &customNetConn{
  211. Conn: conn2,
  212. id: "id2",
  213. }
  214. sshConn1 := NewSSHConnection(customConn1.id, customConn1)
  215. sshConn2 := NewSSHConnection(customConn2.id, customConn2)
  216. username := "test_user"
  217. user := dataprovider.User{
  218. Username: username,
  219. }
  220. c := NewBaseConnection(sshConn1.id+"_1", ProtocolSFTP, user, nil)
  221. c.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano()
  222. fakeConn := &fakeConnection{
  223. BaseConnection: c,
  224. }
  225. // both ssh connections are expired but they should get removed only
  226. // if there is no associated connection
  227. sshConn1.lastActivity = c.lastActivity
  228. sshConn2.lastActivity = c.lastActivity
  229. Connections.AddSSHConnection(sshConn1)
  230. Connections.Add(fakeConn)
  231. assert.Equal(t, Connections.GetActiveSessions(username), 1)
  232. c = NewBaseConnection(sshConn2.id+"_1", ProtocolSSH, user, nil)
  233. fakeConn = &fakeConnection{
  234. BaseConnection: c,
  235. }
  236. Connections.AddSSHConnection(sshConn2)
  237. Connections.Add(fakeConn)
  238. assert.Equal(t, Connections.GetActiveSessions(username), 2)
  239. cFTP := NewBaseConnection("id2", ProtocolFTP, dataprovider.User{}, nil)
  240. cFTP.lastActivity = time.Now().UnixNano()
  241. fakeConn = &fakeConnection{
  242. BaseConnection: cFTP,
  243. }
  244. Connections.Add(fakeConn)
  245. assert.Equal(t, Connections.GetActiveSessions(username), 2)
  246. assert.Len(t, Connections.GetStats(), 3)
  247. Connections.RLock()
  248. assert.Len(t, Connections.sshConnections, 2)
  249. Connections.RUnlock()
  250. startIdleTimeoutTicker(100 * time.Millisecond)
  251. assert.Eventually(t, func() bool { return Connections.GetActiveSessions(username) == 1 }, 1*time.Second, 200*time.Millisecond)
  252. assert.Eventually(t, func() bool {
  253. Connections.RLock()
  254. defer Connections.RUnlock()
  255. return len(Connections.sshConnections) == 1
  256. }, 1*time.Second, 200*time.Millisecond)
  257. stopIdleTimeoutTicker()
  258. assert.Len(t, Connections.GetStats(), 2)
  259. c.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano()
  260. cFTP.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano()
  261. sshConn2.lastActivity = c.lastActivity
  262. startIdleTimeoutTicker(100 * time.Millisecond)
  263. assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 1*time.Second, 200*time.Millisecond)
  264. assert.Eventually(t, func() bool {
  265. Connections.RLock()
  266. defer Connections.RUnlock()
  267. return len(Connections.sshConnections) == 0
  268. }, 1*time.Second, 200*time.Millisecond)
  269. stopIdleTimeoutTicker()
  270. assert.True(t, customConn1.isClosed)
  271. assert.True(t, customConn2.isClosed)
  272. Config = configCopy
  273. }
  274. func TestCloseConnection(t *testing.T) {
  275. c := NewBaseConnection("id", ProtocolSFTP, dataprovider.User{}, nil)
  276. fakeConn := &fakeConnection{
  277. BaseConnection: c,
  278. }
  279. Connections.Add(fakeConn)
  280. assert.Len(t, Connections.GetStats(), 1)
  281. res := Connections.Close(fakeConn.GetID())
  282. assert.True(t, res)
  283. assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 300*time.Millisecond, 50*time.Millisecond)
  284. res = Connections.Close(fakeConn.GetID())
  285. assert.False(t, res)
  286. Connections.Remove(fakeConn.GetID())
  287. }
  288. func TestSwapConnection(t *testing.T) {
  289. c := NewBaseConnection("id", ProtocolFTP, dataprovider.User{}, nil)
  290. fakeConn := &fakeConnection{
  291. BaseConnection: c,
  292. }
  293. Connections.Add(fakeConn)
  294. if assert.Len(t, Connections.GetStats(), 1) {
  295. assert.Equal(t, "", Connections.GetStats()[0].Username)
  296. }
  297. c = NewBaseConnection("id", ProtocolFTP, dataprovider.User{
  298. Username: userTestUsername,
  299. }, nil)
  300. fakeConn = &fakeConnection{
  301. BaseConnection: c,
  302. }
  303. err := Connections.Swap(fakeConn)
  304. assert.NoError(t, err)
  305. if assert.Len(t, Connections.GetStats(), 1) {
  306. assert.Equal(t, userTestUsername, Connections.GetStats()[0].Username)
  307. }
  308. res := Connections.Close(fakeConn.GetID())
  309. assert.True(t, res)
  310. assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 300*time.Millisecond, 50*time.Millisecond)
  311. err = Connections.Swap(fakeConn)
  312. assert.Error(t, err)
  313. }
  314. func TestAtomicUpload(t *testing.T) {
  315. configCopy := Config
  316. Config.UploadMode = UploadModeStandard
  317. assert.False(t, Config.IsAtomicUploadEnabled())
  318. Config.UploadMode = UploadModeAtomic
  319. assert.True(t, Config.IsAtomicUploadEnabled())
  320. Config.UploadMode = UploadModeAtomicWithResume
  321. assert.True(t, Config.IsAtomicUploadEnabled())
  322. Config = configCopy
  323. }
  324. func TestConnectionStatus(t *testing.T) {
  325. username := "test_user"
  326. user := dataprovider.User{
  327. Username: username,
  328. }
  329. fs := vfs.NewOsFs("", os.TempDir(), nil)
  330. c1 := NewBaseConnection("id1", ProtocolSFTP, user, fs)
  331. fakeConn1 := &fakeConnection{
  332. BaseConnection: c1,
  333. }
  334. t1 := NewBaseTransfer(nil, c1, nil, "/p1", "/r1", TransferUpload, 0, 0, 0, true, fs)
  335. t1.BytesReceived = 123
  336. t2 := NewBaseTransfer(nil, c1, nil, "/p2", "/r2", TransferDownload, 0, 0, 0, true, fs)
  337. t2.BytesSent = 456
  338. c2 := NewBaseConnection("id2", ProtocolSSH, user, nil)
  339. fakeConn2 := &fakeConnection{
  340. BaseConnection: c2,
  341. command: "md5sum",
  342. }
  343. c3 := NewBaseConnection("id3", ProtocolWebDAV, user, nil)
  344. fakeConn3 := &fakeConnection{
  345. BaseConnection: c3,
  346. command: "PROPFIND",
  347. }
  348. t3 := NewBaseTransfer(nil, c3, nil, "/p2", "/r2", TransferDownload, 0, 0, 0, true, fs)
  349. Connections.Add(fakeConn1)
  350. Connections.Add(fakeConn2)
  351. Connections.Add(fakeConn3)
  352. stats := Connections.GetStats()
  353. assert.Len(t, stats, 3)
  354. for _, stat := range stats {
  355. assert.Equal(t, stat.Username, username)
  356. assert.True(t, strings.HasPrefix(stat.GetConnectionInfo(), stat.Protocol))
  357. assert.True(t, strings.HasPrefix(stat.GetConnectionDuration(), "00:"))
  358. if stat.ConnectionID == "SFTP_id1" {
  359. assert.Len(t, stat.Transfers, 2)
  360. assert.Greater(t, len(stat.GetTransfersAsString()), 0)
  361. for _, tr := range stat.Transfers {
  362. if tr.OperationType == operationDownload {
  363. assert.True(t, strings.HasPrefix(tr.getConnectionTransferAsString(), "DL"))
  364. } else if tr.OperationType == operationUpload {
  365. assert.True(t, strings.HasPrefix(tr.getConnectionTransferAsString(), "UL"))
  366. }
  367. }
  368. } else if stat.ConnectionID == "DAV_id3" {
  369. assert.Len(t, stat.Transfers, 1)
  370. assert.Greater(t, len(stat.GetTransfersAsString()), 0)
  371. } else {
  372. assert.Equal(t, 0, len(stat.GetTransfersAsString()))
  373. }
  374. }
  375. err := t1.Close()
  376. assert.NoError(t, err)
  377. err = t2.Close()
  378. assert.NoError(t, err)
  379. err = fakeConn3.SignalTransfersAbort()
  380. assert.NoError(t, err)
  381. assert.Equal(t, int32(1), atomic.LoadInt32(&t3.AbortTransfer))
  382. err = t3.Close()
  383. assert.NoError(t, err)
  384. err = fakeConn3.SignalTransfersAbort()
  385. assert.Error(t, err)
  386. Connections.Remove(fakeConn1.GetID())
  387. stats = Connections.GetStats()
  388. assert.Len(t, stats, 2)
  389. assert.Equal(t, fakeConn3.GetID(), stats[0].ConnectionID)
  390. assert.Equal(t, fakeConn2.GetID(), stats[1].ConnectionID)
  391. Connections.Remove(fakeConn2.GetID())
  392. stats = Connections.GetStats()
  393. assert.Len(t, stats, 1)
  394. assert.Equal(t, fakeConn3.GetID(), stats[0].ConnectionID)
  395. Connections.Remove(fakeConn3.GetID())
  396. stats = Connections.GetStats()
  397. assert.Len(t, stats, 0)
  398. }
  399. func TestQuotaScans(t *testing.T) {
  400. username := "username"
  401. assert.True(t, QuotaScans.AddUserQuotaScan(username))
  402. assert.False(t, QuotaScans.AddUserQuotaScan(username))
  403. if assert.Len(t, QuotaScans.GetUsersQuotaScans(), 1) {
  404. assert.Equal(t, QuotaScans.GetUsersQuotaScans()[0].Username, username)
  405. }
  406. assert.True(t, QuotaScans.RemoveUserQuotaScan(username))
  407. assert.False(t, QuotaScans.RemoveUserQuotaScan(username))
  408. assert.Len(t, QuotaScans.GetUsersQuotaScans(), 0)
  409. folderName := "/folder"
  410. assert.True(t, QuotaScans.AddVFolderQuotaScan(folderName))
  411. assert.False(t, QuotaScans.AddVFolderQuotaScan(folderName))
  412. if assert.Len(t, QuotaScans.GetVFoldersQuotaScans(), 1) {
  413. assert.Equal(t, QuotaScans.GetVFoldersQuotaScans()[0].MappedPath, folderName)
  414. }
  415. assert.True(t, QuotaScans.RemoveVFolderQuotaScan(folderName))
  416. assert.False(t, QuotaScans.RemoveVFolderQuotaScan(folderName))
  417. assert.Len(t, QuotaScans.GetVFoldersQuotaScans(), 0)
  418. }
  419. func TestProxyProtocolVersion(t *testing.T) {
  420. c := Configuration{
  421. ProxyProtocol: 1,
  422. }
  423. proxyListener, err := c.GetProxyListener(nil)
  424. assert.NoError(t, err)
  425. assert.Nil(t, proxyListener.Policy)
  426. c.ProxyProtocol = 2
  427. proxyListener, err = c.GetProxyListener(nil)
  428. assert.NoError(t, err)
  429. assert.NotNil(t, proxyListener.Policy)
  430. c.ProxyProtocol = 1
  431. c.ProxyAllowed = []string{"invalid"}
  432. _, err = c.GetProxyListener(nil)
  433. assert.Error(t, err)
  434. c.ProxyProtocol = 2
  435. _, err = c.GetProxyListener(nil)
  436. assert.Error(t, err)
  437. }
  438. func TestProxyProtocol(t *testing.T) {
  439. httpClient := httpclient.GetHTTPClient()
  440. resp, err := httpClient.Get(fmt.Sprintf("http://%v", httpProxyAddr))
  441. if assert.NoError(t, err) {
  442. defer resp.Body.Close()
  443. assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
  444. }
  445. }
  446. func TestPostConnectHook(t *testing.T) {
  447. Config.PostConnectHook = ""
  448. remoteAddr := &net.IPAddr{
  449. IP: net.ParseIP("127.0.0.1"),
  450. Zone: "",
  451. }
  452. assert.NoError(t, Config.ExecutePostConnectHook(remoteAddr.String(), ProtocolFTP))
  453. Config.PostConnectHook = "http://foo\x7f.com/"
  454. assert.Error(t, Config.ExecutePostConnectHook(remoteAddr.String(), ProtocolSFTP))
  455. Config.PostConnectHook = "http://invalid:1234/"
  456. assert.Error(t, Config.ExecutePostConnectHook(remoteAddr.String(), ProtocolSFTP))
  457. Config.PostConnectHook = fmt.Sprintf("http://%v/404", httpAddr)
  458. assert.Error(t, Config.ExecutePostConnectHook(remoteAddr.String(), ProtocolFTP))
  459. Config.PostConnectHook = fmt.Sprintf("http://%v", httpAddr)
  460. assert.NoError(t, Config.ExecutePostConnectHook(remoteAddr.String(), ProtocolFTP))
  461. Config.PostConnectHook = "invalid"
  462. assert.Error(t, Config.ExecutePostConnectHook(remoteAddr.String(), ProtocolFTP))
  463. if runtime.GOOS == osWindows {
  464. Config.PostConnectHook = "C:\\bad\\command"
  465. assert.Error(t, Config.ExecutePostConnectHook(remoteAddr.String(), ProtocolSFTP))
  466. } else {
  467. Config.PostConnectHook = "/invalid/path"
  468. assert.Error(t, Config.ExecutePostConnectHook(remoteAddr.String(), ProtocolSFTP))
  469. hookCmd, err := exec.LookPath("true")
  470. assert.NoError(t, err)
  471. Config.PostConnectHook = hookCmd
  472. assert.NoError(t, Config.ExecutePostConnectHook(remoteAddr.String(), ProtocolSFTP))
  473. }
  474. Config.PostConnectHook = ""
  475. }