defenderdb_test.go 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311
  1. // Copyright (C) 2019-2023 Nicola Murino
  2. //
  3. // This program is free software: you can redistribute it and/or modify
  4. // it under the terms of the GNU Affero General Public License as published
  5. // by the Free Software Foundation, version 3.
  6. //
  7. // This program is distributed in the hope that it will be useful,
  8. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  9. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  10. // GNU Affero General Public License for more details.
  11. //
  12. // You should have received a copy of the GNU Affero General Public License
  13. // along with this program. If not, see <https://www.gnu.org/licenses/>.
  14. package common
  15. import (
  16. "encoding/hex"
  17. "encoding/json"
  18. "os"
  19. "path/filepath"
  20. "testing"
  21. "time"
  22. "github.com/stretchr/testify/assert"
  23. "github.com/drakkan/sftpgo/v2/internal/dataprovider"
  24. "github.com/drakkan/sftpgo/v2/internal/util"
  25. )
  26. func TestBasicDbDefender(t *testing.T) {
  27. if !isDbDefenderSupported() {
  28. t.Skip("this test is not supported with the current database provider")
  29. }
  30. config := &DefenderConfig{
  31. Enabled: true,
  32. BanTime: 10,
  33. BanTimeIncrement: 2,
  34. Threshold: 5,
  35. ScoreInvalid: 2,
  36. ScoreValid: 1,
  37. ScoreLimitExceeded: 3,
  38. ObservationTime: 15,
  39. EntriesSoftLimit: 1,
  40. EntriesHardLimit: 10,
  41. SafeListFile: "slFile",
  42. BlockListFile: "blFile",
  43. }
  44. _, err := newDBDefender(config)
  45. assert.Error(t, err)
  46. bl := HostListFile{
  47. IPAddresses: []string{"172.16.1.1", "172.16.1.2"},
  48. CIDRNetworks: []string{"10.8.0.0/24"},
  49. }
  50. sl := HostListFile{
  51. IPAddresses: []string{"172.16.1.3", "172.16.1.4"},
  52. CIDRNetworks: []string{"192.168.8.0/24"},
  53. }
  54. blFile := filepath.Join(os.TempDir(), "bl.json")
  55. slFile := filepath.Join(os.TempDir(), "sl.json")
  56. data, err := json.Marshal(bl)
  57. assert.NoError(t, err)
  58. err = os.WriteFile(blFile, data, os.ModePerm)
  59. assert.NoError(t, err)
  60. data, err = json.Marshal(sl)
  61. assert.NoError(t, err)
  62. err = os.WriteFile(slFile, data, os.ModePerm)
  63. assert.NoError(t, err)
  64. config.BlockListFile = blFile
  65. _, err = newDBDefender(config)
  66. assert.Error(t, err)
  67. config.SafeListFile = slFile
  68. d, err := newDBDefender(config)
  69. assert.NoError(t, err)
  70. defender := d.(*dbDefender)
  71. assert.True(t, defender.IsBanned("172.16.1.1"))
  72. assert.False(t, defender.IsBanned("172.16.1.10"))
  73. assert.False(t, defender.IsBanned("10.8.1.3"))
  74. assert.True(t, defender.IsBanned("10.8.0.4"))
  75. assert.False(t, defender.IsBanned("invalid ip"))
  76. hosts, err := defender.GetHosts()
  77. assert.NoError(t, err)
  78. assert.Len(t, hosts, 0)
  79. _, err = defender.GetHost("10.8.0.3")
  80. assert.Error(t, err)
  81. defender.AddEvent("172.16.1.4", HostEventLoginFailed)
  82. defender.AddEvent("192.168.8.4", HostEventUserNotFound)
  83. defender.AddEvent("172.16.1.3", HostEventLimitExceeded)
  84. hosts, err = defender.GetHosts()
  85. assert.NoError(t, err)
  86. assert.Len(t, hosts, 0)
  87. assert.True(t, defender.getLastCleanup().IsZero())
  88. testIP := "123.45.67.89"
  89. defender.AddEvent(testIP, HostEventLoginFailed)
  90. lastCleanup := defender.getLastCleanup()
  91. assert.False(t, lastCleanup.IsZero())
  92. score, err := defender.GetScore(testIP)
  93. assert.NoError(t, err)
  94. assert.Equal(t, 1, score)
  95. hosts, err = defender.GetHosts()
  96. assert.NoError(t, err)
  97. if assert.Len(t, hosts, 1) {
  98. assert.Equal(t, 1, hosts[0].Score)
  99. assert.True(t, hosts[0].BanTime.IsZero())
  100. assert.Empty(t, hosts[0].GetBanTime())
  101. }
  102. host, err := defender.GetHost(testIP)
  103. assert.NoError(t, err)
  104. assert.Equal(t, 1, host.Score)
  105. assert.Empty(t, host.GetBanTime())
  106. banTime, err := defender.GetBanTime(testIP)
  107. assert.NoError(t, err)
  108. assert.Nil(t, banTime)
  109. defender.AddEvent(testIP, HostEventLimitExceeded)
  110. score, err = defender.GetScore(testIP)
  111. assert.NoError(t, err)
  112. assert.Equal(t, 4, score)
  113. hosts, err = defender.GetHosts()
  114. assert.NoError(t, err)
  115. if assert.Len(t, hosts, 1) {
  116. assert.Equal(t, 4, hosts[0].Score)
  117. assert.True(t, hosts[0].BanTime.IsZero())
  118. assert.Empty(t, hosts[0].GetBanTime())
  119. }
  120. defender.AddEvent(testIP, HostEventNoLoginTried)
  121. defender.AddEvent(testIP, HostEventNoLoginTried)
  122. score, err = defender.GetScore(testIP)
  123. assert.NoError(t, err)
  124. assert.Equal(t, 0, score)
  125. banTime, err = defender.GetBanTime(testIP)
  126. assert.NoError(t, err)
  127. assert.NotNil(t, banTime)
  128. hosts, err = defender.GetHosts()
  129. assert.NoError(t, err)
  130. if assert.Len(t, hosts, 1) {
  131. assert.Equal(t, 0, hosts[0].Score)
  132. assert.False(t, hosts[0].BanTime.IsZero())
  133. assert.NotEmpty(t, hosts[0].GetBanTime())
  134. assert.Equal(t, hex.EncodeToString([]byte(testIP)), hosts[0].GetID())
  135. }
  136. host, err = defender.GetHost(testIP)
  137. assert.NoError(t, err)
  138. assert.Equal(t, 0, host.Score)
  139. assert.NotEmpty(t, host.GetBanTime())
  140. // ban time should increase
  141. assert.True(t, defender.IsBanned(testIP))
  142. newBanTime, err := defender.GetBanTime(testIP)
  143. assert.NoError(t, err)
  144. assert.True(t, newBanTime.After(*banTime))
  145. assert.True(t, defender.DeleteHost(testIP))
  146. assert.False(t, defender.DeleteHost(testIP))
  147. // test cleanup
  148. testIP1 := "123.45.67.90"
  149. testIP2 := "123.45.67.91"
  150. testIP3 := "123.45.67.92"
  151. for i := 0; i < 3; i++ {
  152. defender.AddEvent(testIP, HostEventNoLoginTried)
  153. defender.AddEvent(testIP1, HostEventNoLoginTried)
  154. defender.AddEvent(testIP2, HostEventNoLoginTried)
  155. }
  156. hosts, err = defender.GetHosts()
  157. assert.NoError(t, err)
  158. assert.Len(t, hosts, 3)
  159. for _, host := range hosts {
  160. assert.Equal(t, 0, host.Score)
  161. assert.False(t, host.BanTime.IsZero())
  162. assert.NotEmpty(t, host.GetBanTime())
  163. }
  164. defender.AddEvent(testIP3, HostEventLoginFailed)
  165. hosts, err = defender.GetHosts()
  166. assert.NoError(t, err)
  167. assert.Len(t, hosts, 4)
  168. // now set a ban time in the past, so the host will be cleanead up
  169. for _, ip := range []string{testIP1, testIP2} {
  170. err = dataprovider.SetDefenderBanTime(ip, util.GetTimeAsMsSinceEpoch(time.Now().Add(-1*time.Minute)))
  171. assert.NoError(t, err)
  172. }
  173. hosts, err = defender.GetHosts()
  174. assert.NoError(t, err)
  175. assert.Len(t, hosts, 4)
  176. for _, host := range hosts {
  177. switch host.IP {
  178. case testIP:
  179. assert.Equal(t, 0, host.Score)
  180. assert.False(t, host.BanTime.IsZero())
  181. assert.NotEmpty(t, host.GetBanTime())
  182. case testIP3:
  183. assert.Equal(t, 1, host.Score)
  184. assert.True(t, host.BanTime.IsZero())
  185. assert.Empty(t, host.GetBanTime())
  186. default:
  187. assert.Equal(t, 6, host.Score)
  188. assert.True(t, host.BanTime.IsZero())
  189. assert.Empty(t, host.GetBanTime())
  190. }
  191. }
  192. host, err = defender.GetHost(testIP)
  193. assert.NoError(t, err)
  194. assert.Equal(t, 0, host.Score)
  195. assert.False(t, host.BanTime.IsZero())
  196. assert.NotEmpty(t, host.GetBanTime())
  197. host, err = defender.GetHost(testIP3)
  198. assert.NoError(t, err)
  199. assert.Equal(t, 1, host.Score)
  200. assert.True(t, host.BanTime.IsZero())
  201. assert.Empty(t, host.GetBanTime())
  202. // set a negative observation time so the from field in the queries will be in the future
  203. // we still should get the banned hosts
  204. defender.config.ObservationTime = -2
  205. assert.Greater(t, defender.getStartObservationTime(), time.Now().UnixMilli())
  206. hosts, err = defender.GetHosts()
  207. assert.NoError(t, err)
  208. if assert.Len(t, hosts, 1) {
  209. assert.Equal(t, testIP, hosts[0].IP)
  210. assert.Equal(t, 0, hosts[0].Score)
  211. assert.False(t, hosts[0].BanTime.IsZero())
  212. assert.NotEmpty(t, hosts[0].GetBanTime())
  213. }
  214. _, err = defender.GetHost(testIP)
  215. assert.NoError(t, err)
  216. // cleanup db
  217. err = dataprovider.CleanupDefender(util.GetTimeAsMsSinceEpoch(time.Now().Add(10 * time.Minute)))
  218. assert.NoError(t, err)
  219. // the banned host must still be there
  220. hosts, err = defender.GetHosts()
  221. assert.NoError(t, err)
  222. if assert.Len(t, hosts, 1) {
  223. assert.Equal(t, testIP, hosts[0].IP)
  224. assert.Equal(t, 0, hosts[0].Score)
  225. assert.False(t, hosts[0].BanTime.IsZero())
  226. assert.NotEmpty(t, hosts[0].GetBanTime())
  227. }
  228. _, err = defender.GetHost(testIP)
  229. assert.NoError(t, err)
  230. err = dataprovider.SetDefenderBanTime(testIP, util.GetTimeAsMsSinceEpoch(time.Now().Add(-1*time.Minute)))
  231. assert.NoError(t, err)
  232. err = dataprovider.CleanupDefender(util.GetTimeAsMsSinceEpoch(time.Now().Add(10 * time.Minute)))
  233. assert.NoError(t, err)
  234. hosts, err = defender.GetHosts()
  235. assert.NoError(t, err)
  236. assert.Len(t, hosts, 0)
  237. err = os.Remove(slFile)
  238. assert.NoError(t, err)
  239. err = os.Remove(blFile)
  240. assert.NoError(t, err)
  241. }
  242. func TestDbDefenderCleanup(t *testing.T) {
  243. if !isDbDefenderSupported() {
  244. t.Skip("this test is not supported with the current database provider")
  245. }
  246. config := &DefenderConfig{
  247. Enabled: true,
  248. BanTime: 10,
  249. BanTimeIncrement: 2,
  250. Threshold: 5,
  251. ScoreInvalid: 2,
  252. ScoreValid: 1,
  253. ScoreLimitExceeded: 3,
  254. ObservationTime: 15,
  255. EntriesSoftLimit: 1,
  256. EntriesHardLimit: 10,
  257. }
  258. d, err := newDBDefender(config)
  259. assert.NoError(t, err)
  260. defender := d.(*dbDefender)
  261. lastCleanup := defender.getLastCleanup()
  262. assert.True(t, lastCleanup.IsZero())
  263. defender.cleanup()
  264. lastCleanup = defender.getLastCleanup()
  265. assert.False(t, lastCleanup.IsZero())
  266. defender.cleanup()
  267. assert.Equal(t, lastCleanup, defender.getLastCleanup())
  268. defender.setLastCleanup(time.Now().Add(-time.Duration(config.ObservationTime) * time.Minute * 4))
  269. time.Sleep(20 * time.Millisecond)
  270. defender.cleanup()
  271. assert.True(t, lastCleanup.Before(defender.getLastCleanup()))
  272. providerConf := dataprovider.GetProviderConfig()
  273. err = dataprovider.Close()
  274. assert.NoError(t, err)
  275. lastCleanup = time.Now().Add(-time.Duration(config.ObservationTime) * time.Minute * 4)
  276. defender.setLastCleanup(lastCleanup)
  277. defender.cleanup()
  278. // cleanup will fail and so last cleanup should be reset to the previous value
  279. assert.Equal(t, lastCleanup, defender.getLastCleanup())
  280. err = dataprovider.Initialize(providerConf, configDir, true)
  281. assert.NoError(t, err)
  282. }
  283. func isDbDefenderSupported() bool {
  284. // SQLite shares the implementation with other SQL-based provider but it makes no sense
  285. // to use it outside test cases
  286. switch dataprovider.GetProviderStatus().Driver {
  287. case dataprovider.MySQLDataProviderName, dataprovider.PGSQLDataProviderName,
  288. dataprovider.CockroachDataProviderName, dataprovider.SQLiteDataProviderName:
  289. return true
  290. default:
  291. return false
  292. }
  293. }