common_test.go 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318
  1. package common
  2. import (
  3. "fmt"
  4. "net"
  5. "net/http"
  6. "os"
  7. "strings"
  8. "testing"
  9. "time"
  10. "github.com/rs/zerolog"
  11. "github.com/spf13/viper"
  12. "github.com/stretchr/testify/assert"
  13. "github.com/drakkan/sftpgo/dataprovider"
  14. "github.com/drakkan/sftpgo/httpclient"
  15. "github.com/drakkan/sftpgo/logger"
  16. )
  17. const (
  18. logSender = "common_test"
  19. httpAddr = "127.0.0.1:9999"
  20. httpProxyAddr = "127.0.0.1:7777"
  21. configDir = ".."
  22. osWindows = "windows"
  23. userTestUsername = "common_test_username"
  24. userTestPwd = "common_test_pwd"
  25. )
  26. type providerConf struct {
  27. Config dataprovider.Config `json:"data_provider" mapstructure:"data_provider"`
  28. }
  29. type fakeConnection struct {
  30. *BaseConnection
  31. sshCommand string
  32. }
  33. func (c *fakeConnection) Disconnect() error {
  34. Connections.Remove(c)
  35. return nil
  36. }
  37. func (c *fakeConnection) GetClientVersion() string {
  38. return ""
  39. }
  40. func (c *fakeConnection) GetCommand() string {
  41. return c.sshCommand
  42. }
  43. func (c *fakeConnection) GetRemoteAddress() string {
  44. return ""
  45. }
  46. func (c *fakeConnection) SetConnDeadline() {}
  47. func TestMain(m *testing.M) {
  48. logfilePath := "common_test.log"
  49. logger.InitLogger(logfilePath, 5, 1, 28, false, zerolog.DebugLevel)
  50. viper.SetEnvPrefix("sftpgo")
  51. replacer := strings.NewReplacer(".", "__")
  52. viper.SetEnvKeyReplacer(replacer)
  53. viper.SetConfigName("sftpgo")
  54. viper.AutomaticEnv()
  55. viper.AllowEmptyEnv(true)
  56. driver, err := initializeDataprovider(-1)
  57. if err != nil {
  58. logger.WarnToConsole("error initializing data provider: %v", err)
  59. os.Exit(1)
  60. }
  61. logger.InfoToConsole("Starting COMMON tests, provider: %v", driver)
  62. Initialize(Configuration{})
  63. httpConfig := httpclient.Config{
  64. Timeout: 5,
  65. }
  66. httpConfig.Initialize(configDir)
  67. go func() {
  68. // start a test HTTP server to receive action notifications
  69. http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
  70. fmt.Fprintf(w, "OK\n")
  71. })
  72. http.HandleFunc("/404", func(w http.ResponseWriter, r *http.Request) {
  73. w.WriteHeader(http.StatusNotFound)
  74. fmt.Fprintf(w, "Not found\n")
  75. })
  76. if err := http.ListenAndServe(httpAddr, nil); err != nil {
  77. logger.ErrorToConsole("could not start HTTP notification server: %v", err)
  78. os.Exit(1)
  79. }
  80. }()
  81. go func() {
  82. Config.ProxyProtocol = 2
  83. listener, err := net.Listen("tcp", httpProxyAddr)
  84. if err != nil {
  85. logger.ErrorToConsole("error creating listener for proxy protocol server: %v", err)
  86. os.Exit(1)
  87. }
  88. proxyListener, err := Config.GetProxyListener(listener)
  89. if err != nil {
  90. logger.ErrorToConsole("error creating proxy protocol listener: %v", err)
  91. os.Exit(1)
  92. }
  93. Config.ProxyProtocol = 0
  94. s := &http.Server{}
  95. if err := s.Serve(proxyListener); err != nil {
  96. logger.ErrorToConsole("could not start HTTP proxy protocol server: %v", err)
  97. os.Exit(1)
  98. }
  99. }()
  100. waitTCPListening(httpAddr)
  101. waitTCPListening(httpProxyAddr)
  102. exitCode := m.Run()
  103. os.Remove(logfilePath) //nolint:errcheck
  104. os.Exit(exitCode)
  105. }
  106. func waitTCPListening(address string) {
  107. for {
  108. conn, err := net.Dial("tcp", address)
  109. if err != nil {
  110. logger.WarnToConsole("tcp server %v not listening: %v\n", address, err)
  111. time.Sleep(100 * time.Millisecond)
  112. continue
  113. }
  114. logger.InfoToConsole("tcp server %v now listening\n", address)
  115. conn.Close()
  116. break
  117. }
  118. }
  119. func initializeDataprovider(trackQuota int) (string, error) {
  120. configDir := ".."
  121. viper.AddConfigPath(configDir)
  122. if err := viper.ReadInConfig(); err != nil {
  123. return "", err
  124. }
  125. var cfg providerConf
  126. if err := viper.Unmarshal(&cfg); err != nil {
  127. return "", err
  128. }
  129. if trackQuota >= 0 && trackQuota <= 2 {
  130. cfg.Config.TrackQuota = trackQuota
  131. }
  132. return cfg.Config.Driver, dataprovider.Initialize(cfg.Config, configDir)
  133. }
  134. func closeDataprovider() error {
  135. return dataprovider.Close()
  136. }
  137. func TestIdleConnections(t *testing.T) {
  138. configCopy := Config
  139. Config.IdleTimeout = 1
  140. Initialize(Config)
  141. username := "test_user"
  142. user := dataprovider.User{
  143. Username: username,
  144. }
  145. c := NewBaseConnection("id", ProtocolSFTP, user, nil)
  146. c.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano()
  147. fakeConn := &fakeConnection{
  148. BaseConnection: c,
  149. }
  150. Connections.Add(fakeConn)
  151. assert.Equal(t, Connections.GetActiveSessions(username), 1)
  152. startIdleTimeoutTicker(100 * time.Millisecond)
  153. assert.Eventually(t, func() bool { return Connections.GetActiveSessions(username) == 0 }, 1*time.Second, 200*time.Millisecond)
  154. stopIdleTimeoutTicker()
  155. Config = configCopy
  156. }
  157. func TestCloseConnection(t *testing.T) {
  158. c := NewBaseConnection("id", ProtocolSFTP, dataprovider.User{}, nil)
  159. fakeConn := &fakeConnection{
  160. BaseConnection: c,
  161. }
  162. Connections.Add(fakeConn)
  163. assert.Len(t, Connections.GetStats(), 1)
  164. res := Connections.Close(fakeConn.GetID())
  165. assert.True(t, res)
  166. assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 300*time.Millisecond, 50*time.Millisecond)
  167. res = Connections.Close(fakeConn.GetID())
  168. assert.False(t, res)
  169. Connections.Remove(fakeConn)
  170. }
  171. func TestAtomicUpload(t *testing.T) {
  172. configCopy := Config
  173. Config.UploadMode = UploadModeStandard
  174. assert.False(t, Config.IsAtomicUploadEnabled())
  175. Config.UploadMode = UploadModeAtomic
  176. assert.True(t, Config.IsAtomicUploadEnabled())
  177. Config.UploadMode = UploadModeAtomicWithResume
  178. assert.True(t, Config.IsAtomicUploadEnabled())
  179. Config = configCopy
  180. }
  181. func TestConnectionStatus(t *testing.T) {
  182. username := "test_user"
  183. user := dataprovider.User{
  184. Username: username,
  185. }
  186. c1 := NewBaseConnection("id1", ProtocolSFTP, user, nil)
  187. fakeConn1 := &fakeConnection{
  188. BaseConnection: c1,
  189. }
  190. t1 := NewBaseTransfer(nil, c1, nil, "/p1", "/r1", TransferUpload, 0, 0, true)
  191. t1.BytesReceived = 123
  192. t2 := NewBaseTransfer(nil, c1, nil, "/p2", "/r2", TransferDownload, 0, 0, true)
  193. t2.BytesSent = 456
  194. c2 := NewBaseConnection("id2", ProtocolSSH, user, nil)
  195. fakeConn2 := &fakeConnection{
  196. BaseConnection: c2,
  197. sshCommand: "md5sum",
  198. }
  199. Connections.Add(fakeConn1)
  200. Connections.Add(fakeConn2)
  201. stats := Connections.GetStats()
  202. assert.Len(t, stats, 2)
  203. for _, stat := range stats {
  204. assert.Equal(t, stat.Username, username)
  205. assert.True(t, strings.HasPrefix(stat.GetConnectionInfo(), stat.Protocol))
  206. assert.True(t, strings.HasPrefix(stat.GetConnectionDuration(), "00:"))
  207. if stat.ConnectionID == "SFTP_id1" {
  208. assert.Len(t, stat.Transfers, 2)
  209. assert.Greater(t, len(stat.GetTransfersAsString()), 0)
  210. for _, tr := range stat.Transfers {
  211. if tr.OperationType == operationDownload {
  212. assert.True(t, strings.HasPrefix(tr.getConnectionTransferAsString(), "DL"))
  213. } else if tr.OperationType == operationUpload {
  214. assert.True(t, strings.HasPrefix(tr.getConnectionTransferAsString(), "UL"))
  215. }
  216. }
  217. } else {
  218. assert.Equal(t, 0, len(stat.GetTransfersAsString()))
  219. }
  220. }
  221. err := t1.Close()
  222. assert.NoError(t, err)
  223. err = t2.Close()
  224. assert.NoError(t, err)
  225. Connections.Remove(fakeConn1)
  226. Connections.Remove(fakeConn2)
  227. stats = Connections.GetStats()
  228. assert.Len(t, stats, 0)
  229. }
  230. func TestQuotaScans(t *testing.T) {
  231. username := "username"
  232. assert.True(t, QuotaScans.AddUserQuotaScan(username))
  233. assert.False(t, QuotaScans.AddUserQuotaScan(username))
  234. if assert.Len(t, QuotaScans.GetUsersQuotaScans(), 1) {
  235. assert.Equal(t, QuotaScans.GetUsersQuotaScans()[0].Username, username)
  236. }
  237. assert.True(t, QuotaScans.RemoveUserQuotaScan(username))
  238. assert.False(t, QuotaScans.RemoveUserQuotaScan(username))
  239. assert.Len(t, QuotaScans.GetUsersQuotaScans(), 0)
  240. folderName := "/folder"
  241. assert.True(t, QuotaScans.AddVFolderQuotaScan(folderName))
  242. assert.False(t, QuotaScans.AddVFolderQuotaScan(folderName))
  243. if assert.Len(t, QuotaScans.GetVFoldersQuotaScans(), 1) {
  244. assert.Equal(t, QuotaScans.GetVFoldersQuotaScans()[0].MappedPath, folderName)
  245. }
  246. assert.True(t, QuotaScans.RemoveVFolderQuotaScan(folderName))
  247. assert.False(t, QuotaScans.RemoveVFolderQuotaScan(folderName))
  248. assert.Len(t, QuotaScans.GetVFoldersQuotaScans(), 0)
  249. }
  250. func TestProxyProtocolVersion(t *testing.T) {
  251. c := Configuration{
  252. ProxyProtocol: 1,
  253. }
  254. proxyListener, err := c.GetProxyListener(nil)
  255. assert.NoError(t, err)
  256. assert.Nil(t, proxyListener.Policy)
  257. c.ProxyProtocol = 2
  258. proxyListener, err = c.GetProxyListener(nil)
  259. assert.NoError(t, err)
  260. assert.NotNil(t, proxyListener.Policy)
  261. c.ProxyProtocol = 1
  262. c.ProxyAllowed = []string{"invalid"}
  263. _, err = c.GetProxyListener(nil)
  264. assert.Error(t, err)
  265. c.ProxyProtocol = 2
  266. _, err = c.GetProxyListener(nil)
  267. assert.Error(t, err)
  268. }
  269. func TestProxyProtocol(t *testing.T) {
  270. httpClient := httpclient.GetHTTPClient()
  271. resp, err := httpClient.Get(fmt.Sprintf("http://%v", httpProxyAddr))
  272. if assert.NoError(t, err) {
  273. defer resp.Body.Close()
  274. assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
  275. }
  276. }