187 lines
6.7 KiB
Go
187 lines
6.7 KiB
Go
package middleware
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/sha256"
|
|
"fmt"
|
|
"net/http"
|
|
|
|
"github.com/ente-io/museum/ente"
|
|
"github.com/ente-io/museum/pkg/controller"
|
|
"github.com/ente-io/museum/pkg/controller/discord"
|
|
"github.com/ente-io/museum/pkg/repo"
|
|
"github.com/ente-io/museum/pkg/utils/array"
|
|
"github.com/ente-io/museum/pkg/utils/auth"
|
|
"github.com/ente-io/museum/pkg/utils/network"
|
|
"github.com/ente-io/museum/pkg/utils/time"
|
|
"github.com/ente-io/stacktrace"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/patrickmn/go-cache"
|
|
"github.com/sirupsen/logrus"
|
|
)
|
|
|
|
var passwordWhiteListedURLs = []string{"/public-collection/info", "/public-collection/report-abuse", "/public-collection/verify-password"}
|
|
var whitelistedCollectionShareIDs = []int64{111}
|
|
|
|
// AccessTokenMiddleware intercepts and authenticates incoming requests
|
|
type AccessTokenMiddleware struct {
|
|
PublicCollectionRepo *repo.PublicCollectionRepository
|
|
PublicCollectionCtrl *controller.PublicCollectionController
|
|
CollectionRepo *repo.CollectionRepository
|
|
Cache *cache.Cache
|
|
BillingCtrl *controller.BillingController
|
|
DiscordController *discord.DiscordController
|
|
}
|
|
|
|
// AccessTokenAuthMiddleware returns a middle ware that extracts the `X-Auth-Access-Token`
|
|
// within the header of a request and uses it to validate the access token and set the
|
|
// ente.PublicAccessContext with auth.PublicAccessKey as key
|
|
func (m *AccessTokenMiddleware) AccessTokenAuthMiddleware(urlSanitizer func(_ *gin.Context) string) gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
accessToken := auth.GetAccessToken(c)
|
|
if accessToken == "" {
|
|
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "missing accessToken"})
|
|
return
|
|
}
|
|
clientIP := network.GetClientIP(c)
|
|
userAgent := c.GetHeader("User-Agent")
|
|
var publicCollectionSummary ente.PublicCollectionSummary
|
|
var err error
|
|
|
|
cacheKey := computeHashKeyForList([]string{accessToken, clientIP, userAgent}, ":")
|
|
cachedValue, cacheHit := m.Cache.Get(cacheKey)
|
|
if !cacheHit {
|
|
publicCollectionSummary, err = m.PublicCollectionRepo.GetCollectionSummaryByToken(c, accessToken)
|
|
if err != nil {
|
|
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
|
return
|
|
}
|
|
if publicCollectionSummary.IsDisabled {
|
|
c.AbortWithStatusJSON(http.StatusGone, gin.H{"error": "disabled token"})
|
|
return
|
|
}
|
|
// validate if user still has active paid subscription
|
|
if err = m.validateOwnersSubscription(publicCollectionSummary.CollectionID); err != nil {
|
|
logrus.WithError(err).Warn("failed to verify active paid subscription")
|
|
c.AbortWithStatusJSON(http.StatusGone, gin.H{"error": "no active subscription"})
|
|
return
|
|
}
|
|
|
|
// validate device limit
|
|
reached, err := m.isDeviceLimitReached(c, publicCollectionSummary, clientIP, userAgent)
|
|
if err != nil {
|
|
logrus.WithError(err).Error("failed to check device limit")
|
|
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "something went wrong"})
|
|
return
|
|
}
|
|
if reached {
|
|
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{"error": "reached device limit"})
|
|
return
|
|
}
|
|
} else {
|
|
publicCollectionSummary = cachedValue.(ente.PublicCollectionSummary)
|
|
}
|
|
|
|
if publicCollectionSummary.ValidTill > 0 && // expiry time is defined, 0 indicates no expiry
|
|
publicCollectionSummary.ValidTill < time.Microseconds() {
|
|
c.AbortWithStatusJSON(http.StatusGone, gin.H{"error": "expired token"})
|
|
return
|
|
}
|
|
|
|
// checks password protected public collection
|
|
if publicCollectionSummary.PassHash != nil && *publicCollectionSummary.PassHash != "" {
|
|
reqPath := urlSanitizer(c)
|
|
if err = m.validatePassword(c, reqPath, publicCollectionSummary); err != nil {
|
|
logrus.WithError(err).Warn("password validation failed")
|
|
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": err})
|
|
return
|
|
}
|
|
}
|
|
|
|
if !cacheHit {
|
|
m.Cache.Set(cacheKey, publicCollectionSummary, cache.DefaultExpiration)
|
|
}
|
|
|
|
c.Set(auth.PublicAccessKey, ente.PublicAccessContext{
|
|
ID: publicCollectionSummary.ID,
|
|
IP: clientIP,
|
|
UserAgent: userAgent,
|
|
CollectionID: publicCollectionSummary.CollectionID,
|
|
})
|
|
c.Next()
|
|
}
|
|
}
|
|
func (m *AccessTokenMiddleware) validateOwnersSubscription(cID int64) error {
|
|
userID, err := m.CollectionRepo.GetOwnerID(cID)
|
|
if err != nil {
|
|
return stacktrace.Propagate(err, "")
|
|
}
|
|
return m.BillingCtrl.HasActiveSelfOrFamilySubscription(userID)
|
|
}
|
|
|
|
func (m *AccessTokenMiddleware) isDeviceLimitReached(ctx context.Context,
|
|
collectionSummary ente.PublicCollectionSummary, ip string, ua string) (bool, error) {
|
|
// skip deviceLimit check & record keeping for requests via CF worker
|
|
if network.IsCFWorkerIP(ip) {
|
|
return false, nil
|
|
}
|
|
if collectionSummary.DeviceLimit <= 0 { // no device limit was added
|
|
return false, nil
|
|
}
|
|
sharedID := collectionSummary.ID
|
|
hasAccessedInPast, err := m.PublicCollectionRepo.AccessedInPast(ctx, sharedID, ip, ua)
|
|
if err != nil {
|
|
return false, stacktrace.Propagate(err, "")
|
|
}
|
|
// if the device has accessed the url in the past, let it access it now as well, irrespective of device limit.
|
|
if hasAccessedInPast {
|
|
return false, nil
|
|
}
|
|
count, err := m.PublicCollectionRepo.GetUniqueAccessCount(ctx, sharedID)
|
|
if err != nil {
|
|
return false, stacktrace.Propagate(err, "failed to get unique access count")
|
|
}
|
|
|
|
deviceLimit := int64(collectionSummary.DeviceLimit)
|
|
if deviceLimit == controller.DeviceLimitThreshold {
|
|
deviceLimit = controller.DeviceLimitThresholdMultiplier * controller.DeviceLimitThreshold
|
|
}
|
|
|
|
if count >= controller.DeviceLimitWarningThreshold {
|
|
if !array.Int64InList(sharedID, whitelistedCollectionShareIDs) {
|
|
m.DiscordController.NotifyPotentialAbuse(
|
|
fmt.Sprintf("Album exceeds warning threshold: {CollectionID: %d, ShareID: %d}",
|
|
collectionSummary.CollectionID, collectionSummary.ID))
|
|
}
|
|
}
|
|
|
|
if count >= deviceLimit {
|
|
return true, nil
|
|
}
|
|
err = m.PublicCollectionRepo.RecordAccessHistory(ctx, sharedID, ip, ua)
|
|
return false, stacktrace.Propagate(err, "failed to record access history")
|
|
}
|
|
|
|
// validatePassword will verify if the user is provided correct password for the public album
|
|
func (m *AccessTokenMiddleware) validatePassword(c *gin.Context, reqPath string,
|
|
collectionSummary ente.PublicCollectionSummary) error {
|
|
if array.StringInList(reqPath, passwordWhiteListedURLs) {
|
|
return nil
|
|
}
|
|
accessTokenJWT := auth.GetAccessTokenJWT(c)
|
|
if accessTokenJWT == "" {
|
|
return ente.ErrAuthenticationRequired
|
|
}
|
|
return m.PublicCollectionCtrl.ValidateJWTToken(c, accessTokenJWT, *collectionSummary.PassHash)
|
|
}
|
|
|
|
func computeHashKeyForList(list []string, delim string) string {
|
|
var buffer bytes.Buffer
|
|
for i := range list {
|
|
buffer.WriteString(list[i])
|
|
buffer.WriteString(delim)
|
|
}
|
|
sha := sha256.Sum256(buffer.Bytes())
|
|
return fmt.Sprintf("%x\n", sha)
|
|
}
|