defenderdb_test.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320
  1. // Copyright (C) 2019-2023 Nicola Murino
  2. //
  3. // This program is free software: you can redistribute it and/or modify
  4. // it under the terms of the GNU Affero General Public License as published
  5. // by the Free Software Foundation, version 3.
  6. //
  7. // This program is distributed in the hope that it will be useful,
  8. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  9. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  10. // GNU Affero General Public License for more details.
  11. //
  12. // You should have received a copy of the GNU Affero General Public License
  13. // along with this program. If not, see <https://www.gnu.org/licenses/>.
  14. package common
  15. import (
  16. "encoding/hex"
  17. "testing"
  18. "time"
  19. "github.com/stretchr/testify/assert"
  20. "github.com/drakkan/sftpgo/v2/internal/dataprovider"
  21. "github.com/drakkan/sftpgo/v2/internal/util"
  22. )
  23. func TestBasicDbDefender(t *testing.T) {
  24. if !isDbDefenderSupported() {
  25. t.Skip("this test is not supported with the current database provider")
  26. }
  27. entries := []dataprovider.IPListEntry{
  28. {
  29. IPOrNet: "172.16.1.1/32",
  30. Type: dataprovider.IPListTypeDefender,
  31. Mode: dataprovider.ListModeDeny,
  32. },
  33. {
  34. IPOrNet: "172.16.1.2/32",
  35. Type: dataprovider.IPListTypeDefender,
  36. Mode: dataprovider.ListModeDeny,
  37. },
  38. {
  39. IPOrNet: "10.8.0.0/24",
  40. Type: dataprovider.IPListTypeDefender,
  41. Mode: dataprovider.ListModeDeny,
  42. },
  43. {
  44. IPOrNet: "172.16.1.3/32",
  45. Type: dataprovider.IPListTypeDefender,
  46. Mode: dataprovider.ListModeAllow,
  47. },
  48. {
  49. IPOrNet: "172.16.1.4/32",
  50. Type: dataprovider.IPListTypeDefender,
  51. Mode: dataprovider.ListModeAllow,
  52. },
  53. {
  54. IPOrNet: "192.168.8.0/24",
  55. Type: dataprovider.IPListTypeDefender,
  56. Mode: dataprovider.ListModeAllow,
  57. },
  58. }
  59. for idx := range entries {
  60. e := entries[idx]
  61. err := dataprovider.AddIPListEntry(&e, "", "", "")
  62. assert.NoError(t, err)
  63. }
  64. config := &DefenderConfig{
  65. Enabled: true,
  66. BanTime: 10,
  67. BanTimeIncrement: 2,
  68. Threshold: 5,
  69. ScoreInvalid: 2,
  70. ScoreValid: 1,
  71. ScoreNoAuth: 2,
  72. ScoreLimitExceeded: 3,
  73. ObservationTime: 15,
  74. EntriesSoftLimit: 1,
  75. EntriesHardLimit: 10,
  76. }
  77. d, err := newDBDefender(config)
  78. assert.NoError(t, err)
  79. defender := d.(*dbDefender)
  80. assert.True(t, defender.IsBanned("172.16.1.1", ProtocolFTP))
  81. assert.False(t, defender.IsBanned("172.16.1.10", ProtocolSSH))
  82. assert.False(t, defender.IsBanned("10.8.1.3", ProtocolHTTP))
  83. assert.True(t, defender.IsBanned("10.8.0.4", ProtocolWebDAV))
  84. assert.False(t, defender.IsBanned("invalid ip", ProtocolSSH))
  85. hosts, err := defender.GetHosts()
  86. assert.NoError(t, err)
  87. assert.Len(t, hosts, 0)
  88. _, err = defender.GetHost("10.8.0.3")
  89. assert.Error(t, err)
  90. defender.AddEvent("172.16.1.4", ProtocolSSH, HostEventLoginFailed)
  91. defender.AddEvent("192.168.8.4", ProtocolSSH, HostEventUserNotFound)
  92. defender.AddEvent("172.16.1.3", ProtocolSSH, HostEventLimitExceeded)
  93. hosts, err = defender.GetHosts()
  94. assert.NoError(t, err)
  95. assert.Len(t, hosts, 0)
  96. assert.True(t, defender.getLastCleanup().IsZero())
  97. testIP := "123.45.67.89"
  98. defender.AddEvent(testIP, ProtocolSSH, HostEventLoginFailed)
  99. lastCleanup := defender.getLastCleanup()
  100. assert.False(t, lastCleanup.IsZero())
  101. score, err := defender.GetScore(testIP)
  102. assert.NoError(t, err)
  103. assert.Equal(t, 1, score)
  104. hosts, err = defender.GetHosts()
  105. assert.NoError(t, err)
  106. if assert.Len(t, hosts, 1) {
  107. assert.Equal(t, 1, hosts[0].Score)
  108. assert.True(t, hosts[0].BanTime.IsZero())
  109. assert.Empty(t, hosts[0].GetBanTime())
  110. }
  111. host, err := defender.GetHost(testIP)
  112. assert.NoError(t, err)
  113. assert.Equal(t, 1, host.Score)
  114. assert.Empty(t, host.GetBanTime())
  115. banTime, err := defender.GetBanTime(testIP)
  116. assert.NoError(t, err)
  117. assert.Nil(t, banTime)
  118. defender.AddEvent(testIP, ProtocolSSH, HostEventLimitExceeded)
  119. score, err = defender.GetScore(testIP)
  120. assert.NoError(t, err)
  121. assert.Equal(t, 4, score)
  122. hosts, err = defender.GetHosts()
  123. assert.NoError(t, err)
  124. if assert.Len(t, hosts, 1) {
  125. assert.Equal(t, 4, hosts[0].Score)
  126. assert.True(t, hosts[0].BanTime.IsZero())
  127. assert.Empty(t, hosts[0].GetBanTime())
  128. }
  129. defender.AddEvent(testIP, ProtocolSSH, HostEventNoLoginTried)
  130. defender.AddEvent(testIP, ProtocolSSH, HostEventNoLoginTried)
  131. score, err = defender.GetScore(testIP)
  132. assert.NoError(t, err)
  133. assert.Equal(t, 0, score)
  134. banTime, err = defender.GetBanTime(testIP)
  135. assert.NoError(t, err)
  136. assert.NotNil(t, banTime)
  137. hosts, err = defender.GetHosts()
  138. assert.NoError(t, err)
  139. if assert.Len(t, hosts, 1) {
  140. assert.Equal(t, 0, hosts[0].Score)
  141. assert.False(t, hosts[0].BanTime.IsZero())
  142. assert.NotEmpty(t, hosts[0].GetBanTime())
  143. assert.Equal(t, hex.EncodeToString([]byte(testIP)), hosts[0].GetID())
  144. }
  145. host, err = defender.GetHost(testIP)
  146. assert.NoError(t, err)
  147. assert.Equal(t, 0, host.Score)
  148. assert.NotEmpty(t, host.GetBanTime())
  149. // ban time should increase
  150. assert.True(t, defender.IsBanned(testIP, ProtocolSSH))
  151. newBanTime, err := defender.GetBanTime(testIP)
  152. assert.NoError(t, err)
  153. assert.True(t, newBanTime.After(*banTime))
  154. assert.True(t, defender.DeleteHost(testIP))
  155. assert.False(t, defender.DeleteHost(testIP))
  156. // test cleanup
  157. testIP1 := "123.45.67.90"
  158. testIP2 := "123.45.67.91"
  159. testIP3 := "123.45.67.92"
  160. for i := 0; i < 3; i++ {
  161. defender.AddEvent(testIP, ProtocolSSH, HostEventUserNotFound)
  162. defender.AddEvent(testIP1, ProtocolSSH, HostEventNoLoginTried)
  163. defender.AddEvent(testIP2, ProtocolSSH, HostEventUserNotFound)
  164. }
  165. hosts, err = defender.GetHosts()
  166. assert.NoError(t, err)
  167. assert.Len(t, hosts, 3)
  168. for _, host := range hosts {
  169. assert.Equal(t, 0, host.Score)
  170. assert.False(t, host.BanTime.IsZero())
  171. assert.NotEmpty(t, host.GetBanTime())
  172. }
  173. defender.AddEvent(testIP3, ProtocolSSH, HostEventLoginFailed)
  174. hosts, err = defender.GetHosts()
  175. assert.NoError(t, err)
  176. assert.Len(t, hosts, 4)
  177. // now set a ban time in the past, so the host will be cleanead up
  178. for _, ip := range []string{testIP1, testIP2} {
  179. err = dataprovider.SetDefenderBanTime(ip, util.GetTimeAsMsSinceEpoch(time.Now().Add(-1*time.Minute)))
  180. assert.NoError(t, err)
  181. }
  182. hosts, err = defender.GetHosts()
  183. assert.NoError(t, err)
  184. assert.Len(t, hosts, 4)
  185. for _, host := range hosts {
  186. switch host.IP {
  187. case testIP:
  188. assert.Equal(t, 0, host.Score)
  189. assert.False(t, host.BanTime.IsZero())
  190. assert.NotEmpty(t, host.GetBanTime())
  191. case testIP3:
  192. assert.Equal(t, 1, host.Score)
  193. assert.True(t, host.BanTime.IsZero())
  194. assert.Empty(t, host.GetBanTime())
  195. default:
  196. assert.Equal(t, 6, host.Score)
  197. assert.True(t, host.BanTime.IsZero())
  198. assert.Empty(t, host.GetBanTime())
  199. }
  200. }
  201. host, err = defender.GetHost(testIP)
  202. assert.NoError(t, err)
  203. assert.Equal(t, 0, host.Score)
  204. assert.False(t, host.BanTime.IsZero())
  205. assert.NotEmpty(t, host.GetBanTime())
  206. host, err = defender.GetHost(testIP3)
  207. assert.NoError(t, err)
  208. assert.Equal(t, 1, host.Score)
  209. assert.True(t, host.BanTime.IsZero())
  210. assert.Empty(t, host.GetBanTime())
  211. // set a negative observation time so the from field in the queries will be in the future
  212. // we still should get the banned hosts
  213. defender.config.ObservationTime = -2
  214. assert.Greater(t, defender.getStartObservationTime(), time.Now().UnixMilli())
  215. hosts, err = defender.GetHosts()
  216. assert.NoError(t, err)
  217. if assert.Len(t, hosts, 1) {
  218. assert.Equal(t, testIP, hosts[0].IP)
  219. assert.Equal(t, 0, hosts[0].Score)
  220. assert.False(t, hosts[0].BanTime.IsZero())
  221. assert.NotEmpty(t, hosts[0].GetBanTime())
  222. }
  223. _, err = defender.GetHost(testIP)
  224. assert.NoError(t, err)
  225. // cleanup db
  226. err = dataprovider.CleanupDefender(util.GetTimeAsMsSinceEpoch(time.Now().Add(10 * time.Minute)))
  227. assert.NoError(t, err)
  228. // the banned host must still be there
  229. hosts, err = defender.GetHosts()
  230. assert.NoError(t, err)
  231. if assert.Len(t, hosts, 1) {
  232. assert.Equal(t, testIP, hosts[0].IP)
  233. assert.Equal(t, 0, hosts[0].Score)
  234. assert.False(t, hosts[0].BanTime.IsZero())
  235. assert.NotEmpty(t, hosts[0].GetBanTime())
  236. }
  237. _, err = defender.GetHost(testIP)
  238. assert.NoError(t, err)
  239. err = dataprovider.SetDefenderBanTime(testIP, util.GetTimeAsMsSinceEpoch(time.Now().Add(-1*time.Minute)))
  240. assert.NoError(t, err)
  241. err = dataprovider.CleanupDefender(util.GetTimeAsMsSinceEpoch(time.Now().Add(10 * time.Minute)))
  242. assert.NoError(t, err)
  243. hosts, err = defender.GetHosts()
  244. assert.NoError(t, err)
  245. assert.Len(t, hosts, 0)
  246. for _, e := range entries {
  247. err := dataprovider.DeleteIPListEntry(e.IPOrNet, e.Type, "", "", "")
  248. assert.NoError(t, err)
  249. }
  250. }
  251. func TestDbDefenderCleanup(t *testing.T) {
  252. if !isDbDefenderSupported() {
  253. t.Skip("this test is not supported with the current database provider")
  254. }
  255. config := &DefenderConfig{
  256. Enabled: true,
  257. BanTime: 10,
  258. BanTimeIncrement: 2,
  259. Threshold: 5,
  260. ScoreInvalid: 2,
  261. ScoreValid: 1,
  262. ScoreLimitExceeded: 3,
  263. ObservationTime: 15,
  264. EntriesSoftLimit: 1,
  265. EntriesHardLimit: 10,
  266. }
  267. d, err := newDBDefender(config)
  268. assert.NoError(t, err)
  269. defender := d.(*dbDefender)
  270. lastCleanup := defender.getLastCleanup()
  271. assert.True(t, lastCleanup.IsZero())
  272. defender.cleanup()
  273. lastCleanup = defender.getLastCleanup()
  274. assert.False(t, lastCleanup.IsZero())
  275. defender.cleanup()
  276. assert.Equal(t, lastCleanup, defender.getLastCleanup())
  277. defender.setLastCleanup(time.Time{})
  278. assert.True(t, defender.getLastCleanup().IsZero())
  279. defender.setLastCleanup(time.Now().Add(-time.Duration(config.ObservationTime) * time.Minute * 4))
  280. time.Sleep(20 * time.Millisecond)
  281. defender.cleanup()
  282. assert.True(t, lastCleanup.Before(defender.getLastCleanup()))
  283. providerConf := dataprovider.GetProviderConfig()
  284. err = dataprovider.Close()
  285. assert.NoError(t, err)
  286. lastCleanup = util.GetTimeFromMsecSinceEpoch(time.Now().Add(-time.Duration(config.ObservationTime) * time.Minute * 4).UnixMilli())
  287. defender.setLastCleanup(lastCleanup)
  288. defender.cleanup()
  289. // cleanup will fail and so last cleanup should be reset to the previous value
  290. assert.Equal(t, lastCleanup, defender.getLastCleanup())
  291. err = dataprovider.Initialize(providerConf, configDir, true)
  292. assert.NoError(t, err)
  293. }
  294. func isDbDefenderSupported() bool {
  295. // SQLite shares the implementation with other SQL-based provider but it makes no sense
  296. // to use it outside test cases
  297. switch dataprovider.GetProviderStatus().Driver {
  298. case dataprovider.MySQLDataProviderName, dataprovider.PGSQLDataProviderName,
  299. dataprovider.CockroachDataProviderName, dataprovider.SQLiteDataProviderName:
  300. return true
  301. default:
  302. return false
  303. }
  304. }