parent
a5340764a8
commit
12b9ac4db6
2 changed files with 52 additions and 11 deletions
|
@ -178,7 +178,8 @@ func main() {
|
|||
authCache := cache.New(1*time.Minute, 15*time.Minute)
|
||||
accessTokenCache := cache.New(1*time.Minute, 15*time.Minute)
|
||||
discordController := discord.NewDiscordController(userRepo, hostName, environment)
|
||||
rateLimiter := middleware.NewRateLimitMiddleware(discordController)
|
||||
rateLimiter := middleware.NewRateLimitMiddleware(discordController, 5000, 5*time.Second)
|
||||
defer rateLimiter.Stop()
|
||||
|
||||
emailNotificationCtrl := &email.EmailNotificationController{
|
||||
UserRepo: userRepo,
|
||||
|
@ -360,22 +361,22 @@ func main() {
|
|||
server.Use(requestid.New(), middleware.Logger(urlSanitizer), cors(), gzip.Gzip(gzip.DefaultCompression), middleware.PanicRecover())
|
||||
|
||||
publicAPI := server.Group("/")
|
||||
publicAPI.Use(rateLimiter.APIRateLimitMiddleware(urlSanitizer))
|
||||
publicAPI.Use(rateLimiter.GlobalRateLimiter(), rateLimiter.APIRateLimitMiddleware(urlSanitizer))
|
||||
|
||||
privateAPI := server.Group("/")
|
||||
privateAPI.Use(authMiddleware.TokenAuthMiddleware(nil), rateLimiter.APIRateLimitForUserMiddleware(urlSanitizer))
|
||||
privateAPI.Use(rateLimiter.GlobalRateLimiter(), authMiddleware.TokenAuthMiddleware(nil), rateLimiter.APIRateLimitForUserMiddleware(urlSanitizer))
|
||||
|
||||
adminAPI := server.Group("/admin")
|
||||
adminAPI.Use(authMiddleware.TokenAuthMiddleware(nil), authMiddleware.AdminAuthMiddleware())
|
||||
adminAPI.Use(rateLimiter.GlobalRateLimiter(), authMiddleware.TokenAuthMiddleware(nil), authMiddleware.AdminAuthMiddleware())
|
||||
paymentJwtAuthAPI := server.Group("/")
|
||||
paymentJwtAuthAPI.Use(authMiddleware.TokenAuthMiddleware(jwt.PAYMENT.Ptr()))
|
||||
paymentJwtAuthAPI.Use(rateLimiter.GlobalRateLimiter(), authMiddleware.TokenAuthMiddleware(jwt.PAYMENT.Ptr()))
|
||||
|
||||
familiesJwtAuthAPI := server.Group("/")
|
||||
//The middleware order matters. First, the userID must be set in the context, so that we can apply limit for user.
|
||||
familiesJwtAuthAPI.Use(authMiddleware.TokenAuthMiddleware(jwt.FAMILIES.Ptr()), rateLimiter.APIRateLimitForUserMiddleware(urlSanitizer))
|
||||
familiesJwtAuthAPI.Use(rateLimiter.GlobalRateLimiter(), authMiddleware.TokenAuthMiddleware(jwt.FAMILIES.Ptr()), rateLimiter.APIRateLimitForUserMiddleware(urlSanitizer))
|
||||
|
||||
publicCollectionAPI := server.Group("/public-collection")
|
||||
publicCollectionAPI.Use(accessTokenMiddleware.AccessTokenAuthMiddleware(urlSanitizer))
|
||||
publicCollectionAPI.Use(rateLimiter.GlobalRateLimiter(), accessTokenMiddleware.AccessTokenAuthMiddleware(urlSanitizer))
|
||||
|
||||
healthCheckHandler := &api.HealthCheckHandler{
|
||||
DB: db,
|
||||
|
@ -472,7 +473,7 @@ func main() {
|
|||
privateAPI.DELETE("/users/delete", userHandler.DeleteUser)
|
||||
|
||||
accountsJwtAuthAPI := server.Group("/")
|
||||
accountsJwtAuthAPI.Use(authMiddleware.TokenAuthMiddleware(jwt.ACCOUNTS.Ptr()), rateLimiter.APIRateLimitForUserMiddleware(urlSanitizer))
|
||||
accountsJwtAuthAPI.Use(rateLimiter.GlobalRateLimiter(), authMiddleware.TokenAuthMiddleware(jwt.ACCOUNTS.Ptr()), rateLimiter.APIRateLimitForUserMiddleware(urlSanitizer))
|
||||
passkeysHandler := &api.PasskeyHandler{
|
||||
Controller: passkeyCtrl,
|
||||
}
|
||||
|
@ -531,7 +532,7 @@ func main() {
|
|||
|
||||
castCtrl := cast.NewController(&castDb, accessCtrl)
|
||||
castMiddleware := middleware.CastMiddleware{CastCtrl: castCtrl, Cache: authCache}
|
||||
castAPI.Use(castMiddleware.CastAuthMiddleware())
|
||||
castAPI.Use(rateLimiter.GlobalRateLimiter(), castMiddleware.CastAuthMiddleware())
|
||||
|
||||
castHandler := &api.CastHandler{
|
||||
CollectionCtrl: collectionController,
|
||||
|
|
|
@ -5,6 +5,8 @@ import (
|
|||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/ente-io/museum/pkg/controller/discord"
|
||||
"github.com/ente-io/museum/pkg/utils/auth"
|
||||
|
@ -20,14 +22,40 @@ type RateLimitMiddleware struct {
|
|||
limit10ReqPerMin *limiter.Limiter
|
||||
limit200ReqPerSec *limiter.Limiter
|
||||
discordCtrl *discord.DiscordController
|
||||
count int64 // Use int64 for atomic operations
|
||||
limit int64
|
||||
reset time.Duration
|
||||
ticker *time.Ticker
|
||||
}
|
||||
|
||||
func NewRateLimitMiddleware(discordCtrl *discord.DiscordController) *RateLimitMiddleware {
|
||||
return &RateLimitMiddleware{
|
||||
func NewRateLimitMiddleware(discordCtrl *discord.DiscordController, limit int64, reset time.Duration) *RateLimitMiddleware {
|
||||
rl := &RateLimitMiddleware{
|
||||
limit10ReqPerMin: rateLimiter("10-M"),
|
||||
limit200ReqPerSec: rateLimiter("200-S"),
|
||||
discordCtrl: discordCtrl,
|
||||
limit: limit,
|
||||
reset: reset,
|
||||
ticker: time.NewTicker(reset),
|
||||
}
|
||||
go func() {
|
||||
for range rl.ticker.C {
|
||||
atomic.StoreInt64(&rl.count, 0) // Reset the count every reset interval
|
||||
}
|
||||
}()
|
||||
return rl
|
||||
}
|
||||
|
||||
// Increment increments the counter in a thread-safe manner.
|
||||
// Returns true if the increment was within the rate limit, false if the rate limit was exceeded.
|
||||
func (r *RateLimitMiddleware) Increment() bool {
|
||||
// Atomically increment the count
|
||||
newCount := atomic.AddInt64(&r.count, 1)
|
||||
return newCount <= r.limit
|
||||
}
|
||||
|
||||
// Stop the internal ticker, effectively stopping the rate limiter.
|
||||
func (r *RateLimitMiddleware) Stop() {
|
||||
r.ticker.Stop()
|
||||
}
|
||||
|
||||
// rateLimiter will return instance of limiter.Limiter based on internal <limit>-<period>
|
||||
|
@ -44,6 +72,18 @@ func rateLimiter(interval string) *limiter.Limiter {
|
|||
return instance
|
||||
}
|
||||
|
||||
// GlobalRateLimiter rate limits all requests to the server, regardless of the endpoint.
|
||||
func (r *RateLimitMiddleware) GlobalRateLimiter() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if !r.Increment() {
|
||||
go r.discordCtrl.NotifyPotentialAbuse("Global rate limit breached")
|
||||
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{"error": "Rate limit breached, try later"})
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// APIRateLimitMiddleware only rate limits sensitive public endpoints which have a higher risk
|
||||
// of abuse by any bad actor.
|
||||
func (r *RateLimitMiddleware) APIRateLimitMiddleware(urlSanitizer func(_ *gin.Context) string) gin.HandlerFunc {
|
||||
|
|
Loading…
Add table
Reference in a new issue