common_test.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516
  1. package common
  2. import (
  3. "fmt"
  4. "net"
  5. "net/http"
  6. "os"
  7. "os/exec"
  8. "runtime"
  9. "strings"
  10. "sync/atomic"
  11. "testing"
  12. "time"
  13. "github.com/rs/zerolog"
  14. "github.com/spf13/viper"
  15. "github.com/stretchr/testify/assert"
  16. "github.com/drakkan/sftpgo/dataprovider"
  17. "github.com/drakkan/sftpgo/httpclient"
  18. "github.com/drakkan/sftpgo/logger"
  19. "github.com/drakkan/sftpgo/vfs"
  20. )
  // Constants shared by the tests of this package.
const (
	// logSenderTest is the log sender name for these tests.
	logSenderTest = "common_test"
	// osWindows is compared against runtime.GOOS for platform specific cases.
	osWindows = "windows"
	// configDir is the directory searched for the sftpgo configuration.
	configDir = ".."
)

// Addresses of the helper HTTP servers started by TestMain.
const (
	// httpAddr serves the plain HTTP notification endpoints ("/" and "/404").
	httpAddr = "127.0.0.1:9999"
	// httpProxyAddr serves HTTP behind a proxy protocol listener.
	httpProxyAddr = "127.0.0.1:7777"
)

// Credentials of the test user.
const (
	userTestUsername = "common_test_username"
	userTestPwd      = "common_test_pwd"
)
  30. type providerConf struct {
  31. Config dataprovider.Config `json:"data_provider" mapstructure:"data_provider"`
  32. }
  33. type fakeConnection struct {
  34. *BaseConnection
  35. command string
  36. }
  37. func (c *fakeConnection) AddUser(user dataprovider.User) error {
  38. fs, err := user.GetFilesystem(c.GetID())
  39. if err != nil {
  40. return err
  41. }
  42. c.BaseConnection.User = user
  43. c.BaseConnection.Fs = fs
  44. return nil
  45. }
  46. func (c *fakeConnection) Disconnect() error {
  47. Connections.Remove(c.GetID())
  48. return nil
  49. }
  50. func (c *fakeConnection) GetClientVersion() string {
  51. return ""
  52. }
  53. func (c *fakeConnection) GetCommand() string {
  54. return c.command
  55. }
  56. func (c *fakeConnection) GetRemoteAddress() string {
  57. return ""
  58. }
  59. type customNetConn struct {
  60. net.Conn
  61. id string
  62. isClosed bool
  63. }
  64. func (c *customNetConn) Close() error {
  65. Connections.RemoveSSHConnection(c.id)
  66. c.isClosed = true
  67. return c.Conn.Close()
  68. }
  69. func TestMain(m *testing.M) {
  70. logfilePath := "common_test.log"
  71. logger.InitLogger(logfilePath, 5, 1, 28, false, zerolog.DebugLevel)
  72. viper.SetEnvPrefix("sftpgo")
  73. replacer := strings.NewReplacer(".", "__")
  74. viper.SetEnvKeyReplacer(replacer)
  75. viper.SetConfigName("sftpgo")
  76. viper.AutomaticEnv()
  77. viper.AllowEmptyEnv(true)
  78. driver, err := initializeDataprovider(-1)
  79. if err != nil {
  80. logger.WarnToConsole("error initializing data provider: %v", err)
  81. os.Exit(1)
  82. }
  83. logger.InfoToConsole("Starting COMMON tests, provider: %v", driver)
  84. Initialize(Configuration{})
  85. httpConfig := httpclient.Config{
  86. Timeout: 5,
  87. }
  88. httpConfig.Initialize(configDir)
  89. go func() {
  90. // start a test HTTP server to receive action notifications
  91. http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
  92. fmt.Fprintf(w, "OK\n")
  93. })
  94. http.HandleFunc("/404", func(w http.ResponseWriter, r *http.Request) {
  95. w.WriteHeader(http.StatusNotFound)
  96. fmt.Fprintf(w, "Not found\n")
  97. })
  98. if err := http.ListenAndServe(httpAddr, nil); err != nil {
  99. logger.ErrorToConsole("could not start HTTP notification server: %v", err)
  100. os.Exit(1)
  101. }
  102. }()
  103. go func() {
  104. Config.ProxyProtocol = 2
  105. listener, err := net.Listen("tcp", httpProxyAddr)
  106. if err != nil {
  107. logger.ErrorToConsole("error creating listener for proxy protocol server: %v", err)
  108. os.Exit(1)
  109. }
  110. proxyListener, err := Config.GetProxyListener(listener)
  111. if err != nil {
  112. logger.ErrorToConsole("error creating proxy protocol listener: %v", err)
  113. os.Exit(1)
  114. }
  115. Config.ProxyProtocol = 0
  116. s := &http.Server{}
  117. if err := s.Serve(proxyListener); err != nil {
  118. logger.ErrorToConsole("could not start HTTP proxy protocol server: %v", err)
  119. os.Exit(1)
  120. }
  121. }()
  122. waitTCPListening(httpAddr)
  123. waitTCPListening(httpProxyAddr)
  124. exitCode := m.Run()
  125. os.Remove(logfilePath) //nolint:errcheck
  126. os.Exit(exitCode)
  127. }
  128. func waitTCPListening(address string) {
  129. for {
  130. conn, err := net.Dial("tcp", address)
  131. if err != nil {
  132. logger.WarnToConsole("tcp server %v not listening: %v\n", address, err)
  133. time.Sleep(100 * time.Millisecond)
  134. continue
  135. }
  136. logger.InfoToConsole("tcp server %v now listening\n", address)
  137. conn.Close()
  138. break
  139. }
  140. }
  141. func initializeDataprovider(trackQuota int) (string, error) {
  142. configDir := ".."
  143. viper.AddConfigPath(configDir)
  144. if err := viper.ReadInConfig(); err != nil {
  145. return "", err
  146. }
  147. var cfg providerConf
  148. if err := viper.Unmarshal(&cfg); err != nil {
  149. return "", err
  150. }
  151. if trackQuota >= 0 && trackQuota <= 2 {
  152. cfg.Config.TrackQuota = trackQuota
  153. }
  154. return cfg.Config.Driver, dataprovider.Initialize(cfg.Config, configDir)
  155. }
  156. func closeDataprovider() error {
  157. return dataprovider.Close()
  158. }
  159. func TestSSHConnections(t *testing.T) {
  160. conn1, conn2 := net.Pipe()
  161. now := time.Now()
  162. sshConn1 := NewSSHConnection("id1", conn1)
  163. sshConn2 := NewSSHConnection("id2", conn2)
  164. assert.Equal(t, "id1", sshConn1.GetID())
  165. assert.Equal(t, "id2", sshConn2.GetID())
  166. sshConn1.UpdateLastActivity()
  167. assert.GreaterOrEqual(t, sshConn1.GetLastActivity().UnixNano(), now.UnixNano())
  168. Connections.AddSSHConnection(sshConn1)
  169. Connections.AddSSHConnection(sshConn2)
  170. Connections.RLock()
  171. assert.Len(t, Connections.sshConnections, 2)
  172. Connections.RUnlock()
  173. Connections.RemoveSSHConnection(sshConn1.id)
  174. Connections.RLock()
  175. assert.Len(t, Connections.sshConnections, 1)
  176. Connections.RUnlock()
  177. Connections.RemoveSSHConnection(sshConn1.id)
  178. Connections.RLock()
  179. assert.Len(t, Connections.sshConnections, 1)
  180. Connections.RUnlock()
  181. Connections.RemoveSSHConnection(sshConn2.id)
  182. Connections.RLock()
  183. assert.Len(t, Connections.sshConnections, 0)
  184. Connections.RUnlock()
  185. assert.NoError(t, sshConn1.Close())
  186. assert.NoError(t, sshConn2.Close())
  187. }
  188. func TestIdleConnections(t *testing.T) {
  189. configCopy := Config
  190. Config.IdleTimeout = 1
  191. Initialize(Config)
  192. conn1, conn2 := net.Pipe()
  193. customConn1 := &customNetConn{
  194. Conn: conn1,
  195. id: "id1",
  196. }
  197. customConn2 := &customNetConn{
  198. Conn: conn2,
  199. id: "id2",
  200. }
  201. sshConn1 := NewSSHConnection(customConn1.id, customConn1)
  202. sshConn2 := NewSSHConnection(customConn2.id, customConn2)
  203. username := "test_user"
  204. user := dataprovider.User{
  205. Username: username,
  206. }
  207. c := NewBaseConnection(sshConn1.id+"_1", ProtocolSFTP, user, nil)
  208. c.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano()
  209. fakeConn := &fakeConnection{
  210. BaseConnection: c,
  211. }
  212. // both ssh connections are expired but they should get removed only
  213. // if there is no associated connection
  214. sshConn1.lastActivity = c.lastActivity
  215. sshConn2.lastActivity = c.lastActivity
  216. Connections.AddSSHConnection(sshConn1)
  217. Connections.Add(fakeConn)
  218. assert.Equal(t, Connections.GetActiveSessions(username), 1)
  219. c = NewBaseConnection(sshConn2.id+"_1", ProtocolSSH, user, nil)
  220. fakeConn = &fakeConnection{
  221. BaseConnection: c,
  222. }
  223. Connections.AddSSHConnection(sshConn2)
  224. Connections.Add(fakeConn)
  225. assert.Equal(t, Connections.GetActiveSessions(username), 2)
  226. cFTP := NewBaseConnection("id2", ProtocolFTP, dataprovider.User{}, nil)
  227. cFTP.lastActivity = time.Now().UnixNano()
  228. fakeConn = &fakeConnection{
  229. BaseConnection: cFTP,
  230. }
  231. Connections.Add(fakeConn)
  232. assert.Equal(t, Connections.GetActiveSessions(username), 2)
  233. assert.Len(t, Connections.GetStats(), 3)
  234. Connections.RLock()
  235. assert.Len(t, Connections.sshConnections, 2)
  236. Connections.RUnlock()
  237. startIdleTimeoutTicker(100 * time.Millisecond)
  238. assert.Eventually(t, func() bool { return Connections.GetActiveSessions(username) == 1 }, 1*time.Second, 200*time.Millisecond)
  239. assert.Eventually(t, func() bool {
  240. Connections.RLock()
  241. defer Connections.RUnlock()
  242. return len(Connections.sshConnections) == 1
  243. }, 1*time.Second, 200*time.Millisecond)
  244. stopIdleTimeoutTicker()
  245. assert.Len(t, Connections.GetStats(), 2)
  246. c.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano()
  247. cFTP.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano()
  248. sshConn2.lastActivity = c.lastActivity
  249. startIdleTimeoutTicker(100 * time.Millisecond)
  250. assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 1*time.Second, 200*time.Millisecond)
  251. assert.Eventually(t, func() bool {
  252. Connections.RLock()
  253. defer Connections.RUnlock()
  254. return len(Connections.sshConnections) == 0
  255. }, 1*time.Second, 200*time.Millisecond)
  256. stopIdleTimeoutTicker()
  257. assert.True(t, customConn1.isClosed)
  258. assert.True(t, customConn2.isClosed)
  259. Config = configCopy
  260. }
  261. func TestCloseConnection(t *testing.T) {
  262. c := NewBaseConnection("id", ProtocolSFTP, dataprovider.User{}, nil)
  263. fakeConn := &fakeConnection{
  264. BaseConnection: c,
  265. }
  266. Connections.Add(fakeConn)
  267. assert.Len(t, Connections.GetStats(), 1)
  268. res := Connections.Close(fakeConn.GetID())
  269. assert.True(t, res)
  270. assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 300*time.Millisecond, 50*time.Millisecond)
  271. res = Connections.Close(fakeConn.GetID())
  272. assert.False(t, res)
  273. Connections.Remove(fakeConn.GetID())
  274. }
  275. func TestSwapConnection(t *testing.T) {
  276. c := NewBaseConnection("id", ProtocolFTP, dataprovider.User{}, nil)
  277. fakeConn := &fakeConnection{
  278. BaseConnection: c,
  279. }
  280. Connections.Add(fakeConn)
  281. if assert.Len(t, Connections.GetStats(), 1) {
  282. assert.Equal(t, "", Connections.GetStats()[0].Username)
  283. }
  284. c = NewBaseConnection("id", ProtocolFTP, dataprovider.User{
  285. Username: userTestUsername,
  286. }, nil)
  287. fakeConn = &fakeConnection{
  288. BaseConnection: c,
  289. }
  290. err := Connections.Swap(fakeConn)
  291. assert.NoError(t, err)
  292. if assert.Len(t, Connections.GetStats(), 1) {
  293. assert.Equal(t, userTestUsername, Connections.GetStats()[0].Username)
  294. }
  295. res := Connections.Close(fakeConn.GetID())
  296. assert.True(t, res)
  297. assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 300*time.Millisecond, 50*time.Millisecond)
  298. err = Connections.Swap(fakeConn)
  299. assert.Error(t, err)
  300. }
  301. func TestAtomicUpload(t *testing.T) {
  302. configCopy := Config
  303. Config.UploadMode = UploadModeStandard
  304. assert.False(t, Config.IsAtomicUploadEnabled())
  305. Config.UploadMode = UploadModeAtomic
  306. assert.True(t, Config.IsAtomicUploadEnabled())
  307. Config.UploadMode = UploadModeAtomicWithResume
  308. assert.True(t, Config.IsAtomicUploadEnabled())
  309. Config = configCopy
  310. }
  311. func TestConnectionStatus(t *testing.T) {
  312. username := "test_user"
  313. user := dataprovider.User{
  314. Username: username,
  315. }
  316. fs := vfs.NewOsFs("", os.TempDir(), nil)
  317. c1 := NewBaseConnection("id1", ProtocolSFTP, user, fs)
  318. fakeConn1 := &fakeConnection{
  319. BaseConnection: c1,
  320. }
  321. t1 := NewBaseTransfer(nil, c1, nil, "/p1", "/r1", TransferUpload, 0, 0, 0, true, fs)
  322. t1.BytesReceived = 123
  323. t2 := NewBaseTransfer(nil, c1, nil, "/p2", "/r2", TransferDownload, 0, 0, 0, true, fs)
  324. t2.BytesSent = 456
  325. c2 := NewBaseConnection("id2", ProtocolSSH, user, nil)
  326. fakeConn2 := &fakeConnection{
  327. BaseConnection: c2,
  328. command: "md5sum",
  329. }
  330. c3 := NewBaseConnection("id3", ProtocolWebDAV, user, nil)
  331. fakeConn3 := &fakeConnection{
  332. BaseConnection: c3,
  333. command: "PROPFIND",
  334. }
  335. t3 := NewBaseTransfer(nil, c3, nil, "/p2", "/r2", TransferDownload, 0, 0, 0, true, fs)
  336. Connections.Add(fakeConn1)
  337. Connections.Add(fakeConn2)
  338. Connections.Add(fakeConn3)
  339. stats := Connections.GetStats()
  340. assert.Len(t, stats, 3)
  341. for _, stat := range stats {
  342. assert.Equal(t, stat.Username, username)
  343. assert.True(t, strings.HasPrefix(stat.GetConnectionInfo(), stat.Protocol))
  344. assert.True(t, strings.HasPrefix(stat.GetConnectionDuration(), "00:"))
  345. if stat.ConnectionID == "SFTP_id1" {
  346. assert.Len(t, stat.Transfers, 2)
  347. assert.Greater(t, len(stat.GetTransfersAsString()), 0)
  348. for _, tr := range stat.Transfers {
  349. if tr.OperationType == operationDownload {
  350. assert.True(t, strings.HasPrefix(tr.getConnectionTransferAsString(), "DL"))
  351. } else if tr.OperationType == operationUpload {
  352. assert.True(t, strings.HasPrefix(tr.getConnectionTransferAsString(), "UL"))
  353. }
  354. }
  355. } else if stat.ConnectionID == "DAV_id3" {
  356. assert.Len(t, stat.Transfers, 1)
  357. assert.Greater(t, len(stat.GetTransfersAsString()), 0)
  358. } else {
  359. assert.Equal(t, 0, len(stat.GetTransfersAsString()))
  360. }
  361. }
  362. err := t1.Close()
  363. assert.NoError(t, err)
  364. err = t2.Close()
  365. assert.NoError(t, err)
  366. err = fakeConn3.SignalTransfersAbort()
  367. assert.NoError(t, err)
  368. assert.Equal(t, int32(1), atomic.LoadInt32(&t3.AbortTransfer))
  369. err = t3.Close()
  370. assert.NoError(t, err)
  371. err = fakeConn3.SignalTransfersAbort()
  372. assert.Error(t, err)
  373. Connections.Remove(fakeConn1.GetID())
  374. Connections.Remove(fakeConn2.GetID())
  375. Connections.Remove(fakeConn3.GetID())
  376. stats = Connections.GetStats()
  377. assert.Len(t, stats, 0)
  378. }
  379. func TestQuotaScans(t *testing.T) {
  380. username := "username"
  381. assert.True(t, QuotaScans.AddUserQuotaScan(username))
  382. assert.False(t, QuotaScans.AddUserQuotaScan(username))
  383. if assert.Len(t, QuotaScans.GetUsersQuotaScans(), 1) {
  384. assert.Equal(t, QuotaScans.GetUsersQuotaScans()[0].Username, username)
  385. }
  386. assert.True(t, QuotaScans.RemoveUserQuotaScan(username))
  387. assert.False(t, QuotaScans.RemoveUserQuotaScan(username))
  388. assert.Len(t, QuotaScans.GetUsersQuotaScans(), 0)
  389. folderName := "/folder"
  390. assert.True(t, QuotaScans.AddVFolderQuotaScan(folderName))
  391. assert.False(t, QuotaScans.AddVFolderQuotaScan(folderName))
  392. if assert.Len(t, QuotaScans.GetVFoldersQuotaScans(), 1) {
  393. assert.Equal(t, QuotaScans.GetVFoldersQuotaScans()[0].MappedPath, folderName)
  394. }
  395. assert.True(t, QuotaScans.RemoveVFolderQuotaScan(folderName))
  396. assert.False(t, QuotaScans.RemoveVFolderQuotaScan(folderName))
  397. assert.Len(t, QuotaScans.GetVFoldersQuotaScans(), 0)
  398. }
  399. func TestProxyProtocolVersion(t *testing.T) {
  400. c := Configuration{
  401. ProxyProtocol: 1,
  402. }
  403. proxyListener, err := c.GetProxyListener(nil)
  404. assert.NoError(t, err)
  405. assert.Nil(t, proxyListener.Policy)
  406. c.ProxyProtocol = 2
  407. proxyListener, err = c.GetProxyListener(nil)
  408. assert.NoError(t, err)
  409. assert.NotNil(t, proxyListener.Policy)
  410. c.ProxyProtocol = 1
  411. c.ProxyAllowed = []string{"invalid"}
  412. _, err = c.GetProxyListener(nil)
  413. assert.Error(t, err)
  414. c.ProxyProtocol = 2
  415. _, err = c.GetProxyListener(nil)
  416. assert.Error(t, err)
  417. }
  418. func TestProxyProtocol(t *testing.T) {
  419. httpClient := httpclient.GetHTTPClient()
  420. resp, err := httpClient.Get(fmt.Sprintf("http://%v", httpProxyAddr))
  421. if assert.NoError(t, err) {
  422. defer resp.Body.Close()
  423. assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
  424. }
  425. }
  426. func TestPostConnectHook(t *testing.T) {
  427. Config.PostConnectHook = ""
  428. remoteAddr := &net.IPAddr{
  429. IP: net.ParseIP("127.0.0.1"),
  430. Zone: "",
  431. }
  432. assert.NoError(t, Config.ExecutePostConnectHook(remoteAddr.String(), ProtocolFTP))
  433. Config.PostConnectHook = "http://foo\x7f.com/"
  434. assert.Error(t, Config.ExecutePostConnectHook(remoteAddr.String(), ProtocolSFTP))
  435. Config.PostConnectHook = "http://invalid:1234/"
  436. assert.Error(t, Config.ExecutePostConnectHook(remoteAddr.String(), ProtocolSFTP))
  437. Config.PostConnectHook = fmt.Sprintf("http://%v/404", httpAddr)
  438. assert.Error(t, Config.ExecutePostConnectHook(remoteAddr.String(), ProtocolFTP))
  439. Config.PostConnectHook = fmt.Sprintf("http://%v", httpAddr)
  440. assert.NoError(t, Config.ExecutePostConnectHook(remoteAddr.String(), ProtocolFTP))
  441. Config.PostConnectHook = "invalid"
  442. assert.Error(t, Config.ExecutePostConnectHook(remoteAddr.String(), ProtocolFTP))
  443. if runtime.GOOS == osWindows {
  444. Config.PostConnectHook = "C:\\bad\\command"
  445. assert.Error(t, Config.ExecutePostConnectHook(remoteAddr.String(), ProtocolSFTP))
  446. } else {
  447. Config.PostConnectHook = "/invalid/path"
  448. assert.Error(t, Config.ExecutePostConnectHook(remoteAddr.String(), ProtocolSFTP))
  449. hookCmd, err := exec.LookPath("true")
  450. assert.NoError(t, err)
  451. Config.PostConnectHook = hookCmd
  452. assert.NoError(t, Config.ExecutePostConnectHook(remoteAddr.String(), ProtocolSFTP))
  453. }
  454. Config.PostConnectHook = ""
  455. }