resetcode.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. package httpd
  2. import (
  3. "encoding/json"
  4. "sync"
  5. "time"
  6. "github.com/drakkan/sftpgo/v2/dataprovider"
  7. "github.com/drakkan/sftpgo/v2/logger"
  8. "github.com/drakkan/sftpgo/v2/util"
  9. )
  10. var (
  11. resetCodeLifespan = 10 * time.Minute
  12. resetCodesMgr resetCodeManager
  13. )
  14. type resetCodeManager interface {
  15. Add(code *resetCode) error
  16. Get(code string) (*resetCode, error)
  17. Delete(code string) error
  18. Cleanup()
  19. }
  20. func newResetCodeManager(isShared int) resetCodeManager {
  21. if isShared == 1 {
  22. logger.Info(logSender, "", "using provider reset code manager")
  23. return &dbResetCodeManager{}
  24. }
  25. logger.Info(logSender, "", "using memory reset code manager")
  26. return &memoryResetCodeManager{}
  27. }
  28. type resetCode struct {
  29. Code string `json:"code"`
  30. Username string `json:"username"`
  31. IsAdmin bool `json:"is_admin"`
  32. ExpiresAt time.Time `json:"expires_at"`
  33. }
  34. func newResetCode(username string, isAdmin bool) *resetCode {
  35. return &resetCode{
  36. Code: util.GenerateUniqueID(),
  37. Username: username,
  38. IsAdmin: isAdmin,
  39. ExpiresAt: time.Now().Add(resetCodeLifespan).UTC(),
  40. }
  41. }
  42. func (c *resetCode) isExpired() bool {
  43. return c.ExpiresAt.Before(time.Now().UTC())
  44. }
  45. type memoryResetCodeManager struct {
  46. resetCodes sync.Map
  47. }
  48. func (m *memoryResetCodeManager) Add(code *resetCode) error {
  49. m.resetCodes.Store(code.Code, code)
  50. return nil
  51. }
  52. func (m *memoryResetCodeManager) Get(code string) (*resetCode, error) {
  53. c, ok := m.resetCodes.Load(code)
  54. if !ok {
  55. return nil, util.NewRecordNotFoundError("reset code not found")
  56. }
  57. return c.(*resetCode), nil
  58. }
  59. func (m *memoryResetCodeManager) Delete(code string) error {
  60. m.resetCodes.Delete(code)
  61. return nil
  62. }
  63. func (m *memoryResetCodeManager) Cleanup() {
  64. m.resetCodes.Range(func(key, value any) bool {
  65. c, ok := value.(*resetCode)
  66. if !ok || c.isExpired() {
  67. m.resetCodes.Delete(key)
  68. }
  69. return true
  70. })
  71. }
  72. type dbResetCodeManager struct{}
  73. func (m *dbResetCodeManager) Add(code *resetCode) error {
  74. session := dataprovider.Session{
  75. Key: code.Code,
  76. Data: code,
  77. Type: dataprovider.SessionTypeResetCode,
  78. Timestamp: util.GetTimeAsMsSinceEpoch(code.ExpiresAt),
  79. }
  80. return dataprovider.AddSharedSession(session)
  81. }
  82. func (m *dbResetCodeManager) Get(code string) (*resetCode, error) {
  83. session, err := dataprovider.GetSharedSession(code)
  84. if err != nil {
  85. return nil, err
  86. }
  87. if session.Timestamp < util.GetTimeAsMsSinceEpoch(time.Now()) {
  88. // expired
  89. return nil, util.NewRecordNotFoundError("reset code expired")
  90. }
  91. return m.decodeData(session.Data)
  92. }
  93. func (m *dbResetCodeManager) decodeData(data any) (*resetCode, error) {
  94. if val, ok := data.([]byte); ok {
  95. c := &resetCode{}
  96. err := json.Unmarshal(val, c)
  97. return c, err
  98. }
  99. logger.Error(logSender, "", "invalid reset code data type %T", data)
  100. return nil, util.NewRecordNotFoundError("invalid reset code")
  101. }
  102. func (m *dbResetCodeManager) Delete(code string) error {
  103. return dataprovider.DeleteSharedSession(code)
  104. }
  105. func (m *dbResetCodeManager) Cleanup() {
  106. dataprovider.CleanupSharedSessions(dataprovider.SessionTypeResetCode, time.Now()) //nolint:errcheck
  107. }