497 lines
19 KiB
Go
497 lines
19 KiB
Go
package controller
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"github.com/ente-io/museum/pkg/controller/commonbilling"
|
|
"strconv"
|
|
|
|
"github.com/ente-io/museum/pkg/repo/storagebonus"
|
|
|
|
"github.com/ente-io/museum/pkg/controller/discord"
|
|
"github.com/ente-io/museum/pkg/controller/email"
|
|
"github.com/ente-io/museum/pkg/utils/array"
|
|
"github.com/ente-io/museum/pkg/utils/billing"
|
|
"github.com/ente-io/museum/pkg/utils/network"
|
|
"github.com/ente-io/museum/pkg/utils/time"
|
|
"github.com/ente-io/stacktrace"
|
|
"github.com/gin-gonic/gin"
|
|
log "github.com/sirupsen/logrus"
|
|
"github.com/spf13/viper"
|
|
|
|
"github.com/ente-io/museum/ente"
|
|
"github.com/ente-io/museum/pkg/repo"
|
|
)
|
|
|
|
// BillingController provides abstractions for handling billing related queries
|
|
type BillingController struct {
|
|
BillingPlansPerAccount ente.BillingPlansPerAccount
|
|
BillingRepo *repo.BillingRepository
|
|
UserRepo *repo.UserRepository
|
|
UsageRepo *repo.UsageRepository
|
|
StorageBonusRepo *storagebonus.Repository
|
|
AppStoreController *AppStoreController
|
|
PlayStoreController *PlayStoreController
|
|
StripeController *StripeController
|
|
DiscordController *discord.DiscordController
|
|
EmailNotificationCtrl *email.EmailNotificationController
|
|
CommonBillCtrl *commonbilling.Controller
|
|
}
|
|
|
|
// Return a new instance of BillingController
|
|
func NewBillingController(
|
|
plans ente.BillingPlansPerAccount,
|
|
appStoreController *AppStoreController,
|
|
playStoreController *PlayStoreController,
|
|
stripeController *StripeController,
|
|
discordController *discord.DiscordController,
|
|
emailNotificationCtrl *email.EmailNotificationController,
|
|
billingRepo *repo.BillingRepository,
|
|
userRepo *repo.UserRepository,
|
|
usageRepo *repo.UsageRepository,
|
|
storageBonusRepo *storagebonus.Repository,
|
|
commonBillCtrl *commonbilling.Controller,
|
|
) *BillingController {
|
|
return &BillingController{
|
|
BillingPlansPerAccount: plans,
|
|
BillingRepo: billingRepo,
|
|
UserRepo: userRepo,
|
|
UsageRepo: usageRepo,
|
|
AppStoreController: appStoreController,
|
|
PlayStoreController: playStoreController,
|
|
StripeController: stripeController,
|
|
DiscordController: discordController,
|
|
EmailNotificationCtrl: emailNotificationCtrl,
|
|
StorageBonusRepo: storageBonusRepo,
|
|
CommonBillCtrl: commonBillCtrl,
|
|
}
|
|
}
|
|
|
|
// GetPlansV2 returns the available subscription plans for the given country and stripe account
|
|
func (c *BillingController) GetPlansV2(countryCode string, stripeAccountCountry ente.StripeAccountCountry) []ente.BillingPlan {
|
|
plans := c.getAllPlans(countryCode, stripeAccountCountry)
|
|
result := make([]ente.BillingPlan, 0)
|
|
ids := billing.GetActivePlanIDs()
|
|
for _, plan := range plans {
|
|
if contains(ids, plan.ID) {
|
|
result = append(result, plan)
|
|
}
|
|
}
|
|
return result
|
|
}
|
|
|
|
// GetStripeAccountCountry returns the stripe account country the user's existing plan is from
|
|
// if he doesn't have a stripe subscription then ente.DefaultStripeAccountCountry is returned
|
|
func (c *BillingController) GetStripeAccountCountry(userID int64) (ente.StripeAccountCountry, error) {
|
|
stipeSubInfo, hasStripeSub, err := c.GetUserStripeSubscriptionInfo(userID)
|
|
if err != nil {
|
|
return "", stacktrace.Propagate(err, "")
|
|
}
|
|
if hasStripeSub {
|
|
return stipeSubInfo.AccountCountry, nil
|
|
} else {
|
|
//if user doesn't have a stripe subscription, return the default stripe account country
|
|
return ente.DefaultStripeAccountCountry, nil
|
|
}
|
|
}
|
|
|
|
// GetUserPlans returns the active plans for a user
|
|
func (c *BillingController) GetUserPlans(ctx *gin.Context, userID int64) ([]ente.BillingPlan, error) {
|
|
stripeSubInfo, hasStripeSub, err := c.GetUserStripeSubscriptionInfo(userID)
|
|
if err != nil {
|
|
return []ente.BillingPlan{}, stacktrace.Propagate(err, "Failed to get user's subscription country and stripe account")
|
|
}
|
|
if hasStripeSub {
|
|
return c.GetPlansV2(stripeSubInfo.PlanCountry, stripeSubInfo.AccountCountry), nil
|
|
} else {
|
|
// user doesn't have a stipe subscription, so return the default account plans for the country the user is from
|
|
return c.GetPlansV2(network.GetClientCountry(ctx), ente.DefaultStripeAccountCountry), nil
|
|
}
|
|
}
|
|
|
|
// GetSubscription returns the current subscription for a user if any
|
|
func (c *BillingController) GetSubscription(ctx *gin.Context, userID int64) (ente.Subscription, error) {
|
|
s, err := c.BillingRepo.GetUserSubscription(userID)
|
|
if err != nil {
|
|
return ente.Subscription{}, stacktrace.Propagate(err, "")
|
|
}
|
|
plan, err := c.getPlanForCountry(s, network.GetClientCountry(ctx))
|
|
if err != nil {
|
|
return ente.Subscription{}, stacktrace.Propagate(err, "")
|
|
}
|
|
s.Price = plan.Price
|
|
s.Period = plan.Period
|
|
return s, nil
|
|
}
|
|
|
|
func (c *BillingController) GetRedirectURL(ctx *gin.Context) (string, error) {
|
|
whitelistedRedirectURLs := viper.GetStringSlice("stripe.whitelisted-redirect-urls")
|
|
redirectURL := ctx.Query("redirectURL")
|
|
if len(redirectURL) > 0 && redirectURL[len(redirectURL)-1:] == "/" { // Ignore the trailing slash
|
|
redirectURL = redirectURL[:len(redirectURL)-1]
|
|
}
|
|
for _, ar := range whitelistedRedirectURLs {
|
|
if ar == redirectURL {
|
|
return ar, nil
|
|
}
|
|
}
|
|
return "", stacktrace.Propagate(ente.ErrBadRequest, fmt.Sprintf("not a whitelistedRedirectURL- %s", redirectURL))
|
|
}
|
|
|
|
// GetActiveSubscription returns user's active subscription or throws a error if no active subscription
|
|
func (c *BillingController) GetActiveSubscription(userID int64) (ente.Subscription, error) {
|
|
subscription, err := c.BillingRepo.GetUserSubscription(userID)
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return subscription, ente.ErrNoActiveSubscription
|
|
}
|
|
if err != nil {
|
|
return subscription, stacktrace.Propagate(err, "")
|
|
}
|
|
expiryBuffer := int64(0)
|
|
if value, ok := billing.ProviderToExpiryGracePeriodMap[subscription.PaymentProvider]; ok {
|
|
expiryBuffer = value
|
|
}
|
|
if (subscription.ExpiryTime + expiryBuffer) < time.Microseconds() {
|
|
return subscription, ente.ErrNoActiveSubscription
|
|
}
|
|
return subscription, nil
|
|
}
|
|
|
|
// IsActivePayingSubscriber validates if the current user is paying customer with active subscription
|
|
func (c *BillingController) IsActivePayingSubscriber(userID int64) error {
|
|
subscription, err := c.GetActiveSubscription(userID)
|
|
var subErr error
|
|
if err != nil {
|
|
subErr = stacktrace.Propagate(err, "")
|
|
} else if !billing.IsActivePaidPlan(subscription) {
|
|
subErr = ente.ErrSharingDisabledForFreeAccounts
|
|
}
|
|
if subErr != nil && (errors.Is(subErr, ente.ErrNoActiveSubscription) || errors.Is(subErr, ente.ErrSharingDisabledForFreeAccounts)) {
|
|
storage, storeErr := c.StorageBonusRepo.GetPaidAddonSurplusStorage(context.Background(), userID)
|
|
if storeErr != nil {
|
|
return storeErr
|
|
}
|
|
if *storage > 0 {
|
|
return nil
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// HasActiveSelfOrFamilySubscription validates if the user or user's family admin has active subscription
|
|
func (c *BillingController) HasActiveSelfOrFamilySubscription(userID int64) error {
|
|
var subscriptionUserID int64
|
|
familyAdminID, err := c.UserRepo.GetFamilyAdminID(userID)
|
|
if err != nil {
|
|
return stacktrace.Propagate(err, "")
|
|
}
|
|
if familyAdminID != nil {
|
|
subscriptionUserID = *familyAdminID
|
|
} else {
|
|
subscriptionUserID = userID
|
|
}
|
|
_, err = c.GetActiveSubscription(subscriptionUserID)
|
|
if err != nil {
|
|
if errors.Is(err, ente.ErrNoActiveSubscription) {
|
|
storage, storeErr := c.StorageBonusRepo.GetPaidAddonSurplusStorage(context.Background(), subscriptionUserID)
|
|
if storeErr != nil {
|
|
return storeErr
|
|
}
|
|
if *storage > 0 {
|
|
return nil
|
|
}
|
|
}
|
|
return stacktrace.Propagate(err, "")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *BillingController) GetUserStripeSubscriptionInfo(userID int64) (ente.StripeSubscriptionInfo, bool, error) {
|
|
s, err := c.BillingRepo.GetUserSubscription(userID)
|
|
if err != nil {
|
|
return ente.StripeSubscriptionInfo{}, false, stacktrace.Propagate(err, "")
|
|
}
|
|
// skipping country code extraction for non-stripe subscriptions
|
|
// as they have same product id across countries and hence can't be distinquished
|
|
if s.PaymentProvider != ente.Stripe {
|
|
return ente.StripeSubscriptionInfo{}, false, nil
|
|
}
|
|
_, countryCode, err := c.getPlanWithCountry(s)
|
|
if err != nil {
|
|
return ente.StripeSubscriptionInfo{}, false, stacktrace.Propagate(err, "")
|
|
}
|
|
return ente.StripeSubscriptionInfo{PlanCountry: countryCode, AccountCountry: s.Attributes.StripeAccountCountry}, true, nil
|
|
}
|
|
|
|
// VerifySubscription verifies and returns the verified subscription
|
|
func (c *BillingController) VerifySubscription(
|
|
userID int64,
|
|
paymentProvider ente.PaymentProvider,
|
|
productID string,
|
|
verificationData string) (ente.Subscription, error) {
|
|
if productID == ente.FreePlanProductID {
|
|
return c.BillingRepo.GetUserSubscription(userID)
|
|
}
|
|
var newSubscription ente.Subscription
|
|
var err error
|
|
switch paymentProvider {
|
|
case ente.PlayStore:
|
|
newSubscription, err = c.PlayStoreController.GetVerifiedSubscription(userID, productID, verificationData)
|
|
case ente.AppStore:
|
|
newSubscription, err = c.AppStoreController.GetVerifiedSubscription(userID, productID, verificationData)
|
|
case ente.Stripe:
|
|
newSubscription, err = c.StripeController.GetVerifiedSubscription(userID, verificationData)
|
|
default:
|
|
err = stacktrace.Propagate(ente.ErrBadRequest, "")
|
|
}
|
|
if err != nil {
|
|
return ente.Subscription{}, stacktrace.Propagate(err, "")
|
|
}
|
|
currentSubscription, err := c.BillingRepo.GetUserSubscription(userID)
|
|
if err != nil {
|
|
return ente.Subscription{}, stacktrace.Propagate(err, "")
|
|
}
|
|
newSubscriptionExpiresSooner := newSubscription.ExpiryTime < currentSubscription.ExpiryTime
|
|
isUpgradingFromFreePlan := currentSubscription.ProductID == ente.FreePlanProductID
|
|
hasChangedProductID := currentSubscription.ProductID != newSubscription.ProductID
|
|
isOutdatedPurchase := !isUpgradingFromFreePlan && !hasChangedProductID && newSubscriptionExpiresSooner
|
|
if isOutdatedPurchase {
|
|
// User is reporting an outdated purchase that was already verified
|
|
// no-op
|
|
log.Info("Outdated purchase reported")
|
|
return currentSubscription, nil
|
|
}
|
|
if newSubscription.Storage < currentSubscription.Storage {
|
|
canDowngrade, canDowngradeErr := c.CommonBillCtrl.CanDowngradeToGivenStorage(newSubscription.Storage, userID)
|
|
if canDowngradeErr != nil {
|
|
return ente.Subscription{}, stacktrace.Propagate(canDowngradeErr, "")
|
|
}
|
|
if !canDowngrade {
|
|
return ente.Subscription{}, stacktrace.Propagate(ente.ErrCannotDowngrade, "")
|
|
}
|
|
log.Info("Usage is good")
|
|
}
|
|
if newSubscription.OriginalTransactionID != "" && newSubscription.OriginalTransactionID != "none" {
|
|
existingSub, existingSubErr := c.BillingRepo.GetSubscriptionForTransaction(newSubscription.OriginalTransactionID, paymentProvider)
|
|
if existingSubErr != nil {
|
|
if errors.Is(existingSubErr, sql.ErrNoRows) {
|
|
log.Info("No subscription created yet")
|
|
} else {
|
|
log.Info("Something went wrong")
|
|
log.WithError(existingSubErr).Error("GetSubscriptionForTransaction failed")
|
|
return ente.Subscription{}, stacktrace.Propagate(existingSubErr, "")
|
|
}
|
|
} else {
|
|
if existingSub.UserID != userID {
|
|
log.WithFields(log.Fields{
|
|
"original_transaction_id": existingSub.OriginalTransactionID,
|
|
"existing_user": existingSub.UserID,
|
|
"current_user": userID,
|
|
}).Error("Subscription for given transactionID is attached with different user")
|
|
log.Info("Subscription attached to different user")
|
|
return ente.Subscription{}, stacktrace.Propagate(&ente.ErrSubscriptionAlreadyClaimed,
|
|
fmt.Sprintf("Subscription with txn id %s already associated with user %d", newSubscription.OriginalTransactionID, existingSub.UserID))
|
|
}
|
|
}
|
|
}
|
|
err = c.BillingRepo.ReplaceSubscription(
|
|
currentSubscription.ID,
|
|
newSubscription,
|
|
)
|
|
if err != nil {
|
|
return ente.Subscription{}, stacktrace.Propagate(err, "")
|
|
}
|
|
log.Info("Replaced subscription")
|
|
newSubscription.ID = currentSubscription.ID
|
|
if paymentProvider == ente.PlayStore &&
|
|
newSubscription.OriginalTransactionID != currentSubscription.OriginalTransactionID {
|
|
// Acknowledge to PlayStore in case of upgrades/downgrades/renewals
|
|
err = c.PlayStoreController.AcknowledgeSubscription(newSubscription.ProductID, verificationData)
|
|
if err != nil {
|
|
log.Error("Error acknowledging subscription ", err)
|
|
}
|
|
}
|
|
if isUpgradingFromFreePlan {
|
|
go func() {
|
|
amount := "unknown"
|
|
plan, _, err := c.getPlanWithCountry(newSubscription)
|
|
if err != nil {
|
|
log.Error(err)
|
|
} else {
|
|
amount = plan.Price
|
|
}
|
|
c.DiscordController.NotifyNewSub(userID, string(paymentProvider), amount)
|
|
}()
|
|
go func() {
|
|
c.EmailNotificationCtrl.OnAccountUpgrade(userID)
|
|
}()
|
|
}
|
|
log.Info("Returning new subscription with ID " + strconv.FormatInt(newSubscription.ID, 10))
|
|
return newSubscription, nil
|
|
}
|
|
|
|
func (c *BillingController) getAllPlans(countryCode string, stripeAccountCountry ente.StripeAccountCountry) []ente.BillingPlan {
|
|
if array.StringInList(countryCode, billing.CountriesInEU) {
|
|
countryCode = "EU"
|
|
}
|
|
countryWisePlans := c.BillingPlansPerAccount[stripeAccountCountry]
|
|
if plans, found := countryWisePlans[countryCode]; found {
|
|
return plans
|
|
}
|
|
// unable to find plans for given country code, return plans for default country
|
|
defaultCountry := billing.GetDefaultPlanCountry()
|
|
return countryWisePlans[defaultCountry]
|
|
}
|
|
|
|
func (c *BillingController) UpdateBillingEmail(userID int64, newEmail string) error {
|
|
subscription, err := c.BillingRepo.GetUserSubscription(userID)
|
|
if err != nil {
|
|
return stacktrace.Propagate(err, "")
|
|
}
|
|
hasStripeSubscription := subscription.PaymentProvider == ente.Stripe
|
|
if hasStripeSubscription {
|
|
err = c.StripeController.UpdateBillingEmail(subscription, newEmail)
|
|
if err != nil {
|
|
return stacktrace.Propagate(err, "")
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *BillingController) UpdateSubscription(r ente.UpdateSubscriptionRequest) error {
|
|
subscription, err := c.BillingRepo.GetUserSubscription(r.UserID)
|
|
if err != nil {
|
|
return stacktrace.Propagate(err, "")
|
|
}
|
|
newSubscription := ente.Subscription{
|
|
Storage: r.Storage,
|
|
ExpiryTime: r.ExpiryTime,
|
|
ProductID: r.ProductID,
|
|
PaymentProvider: r.PaymentProvider,
|
|
OriginalTransactionID: r.TransactionID,
|
|
Attributes: r.Attributes,
|
|
}
|
|
err = c.BillingRepo.ReplaceSubscription(subscription.ID, newSubscription)
|
|
if err != nil {
|
|
return stacktrace.Propagate(err, "")
|
|
}
|
|
err = c.BillingRepo.LogAdminTriggeredSubscriptionUpdate(r)
|
|
return stacktrace.Propagate(err, "")
|
|
}
|
|
|
|
func (c *BillingController) HandleAccountDeletion(ctx context.Context, userID int64, logger *log.Entry) (isCancelled bool, err error) {
|
|
logger.Info("updating billing on account deletion")
|
|
subscription, err := c.BillingRepo.GetUserSubscription(userID)
|
|
if err != nil {
|
|
return false, stacktrace.Propagate(err, "")
|
|
}
|
|
billingLogger := logger.WithFields(log.Fields{
|
|
"customer_id": subscription.Attributes.CustomerID,
|
|
"is_cancelled": subscription.Attributes.IsCancelled,
|
|
"original_txn_id": subscription.OriginalTransactionID,
|
|
"payment_provider": subscription.PaymentProvider,
|
|
"product_id": subscription.ProductID,
|
|
"stripe_account_country": subscription.Attributes.StripeAccountCountry,
|
|
})
|
|
billingLogger.Info("subscription fetched")
|
|
// user on free plan, no action required
|
|
if subscription.ProductID == ente.FreePlanProductID {
|
|
billingLogger.Info("user on free plan")
|
|
return true, nil
|
|
}
|
|
// The word "family" here is a misnomer - these are some manually created
|
|
// accounts for very early adopters, and are unrelated to Family Plans.
|
|
// Cancelation of these accounts will require manual intervention. Ideally,
|
|
// we should never be deleting such accounts.
|
|
if subscription.ProductID == ente.FamilyPlanProductID || subscription.ProductID == "" {
|
|
return false, stacktrace.NewError(fmt.Sprintf("unexpected product id %s", subscription.ProductID), "")
|
|
}
|
|
isCancelled = subscription.Attributes.IsCancelled
|
|
// delete customer data from Stripe if user is on paid plan.
|
|
if subscription.PaymentProvider == ente.Stripe {
|
|
err = c.StripeController.CancelSubAndDeleteCustomer(subscription, billingLogger)
|
|
if err != nil {
|
|
return false, stacktrace.Propagate(err, "")
|
|
}
|
|
// on customer deletion, subscription is automatically cancelled
|
|
isCancelled = true
|
|
} else if subscription.PaymentProvider == ente.AppStore || subscription.PaymentProvider == ente.PlayStore {
|
|
logger.Info("Updating originalTransactionID for app/playStore provider")
|
|
err := c.BillingRepo.UpdateTransactionIDOnDeletion(userID)
|
|
if err != nil {
|
|
return false, stacktrace.Propagate(err, "")
|
|
}
|
|
}
|
|
return isCancelled, nil
|
|
}
|
|
|
|
func (c *BillingController) getPlanWithCountry(s ente.Subscription) (ente.BillingPlan, string, error) {
|
|
var allPlans ente.BillingPlansPerCountry
|
|
if s.PaymentProvider == ente.Stripe {
|
|
allPlans = c.BillingPlansPerAccount[s.Attributes.StripeAccountCountry]
|
|
} else {
|
|
allPlans = c.BillingPlansPerAccount[ente.DefaultStripeAccountCountry]
|
|
}
|
|
subProductID := s.ProductID
|
|
for country, plans := range allPlans {
|
|
for _, plan := range plans {
|
|
if s.PaymentProvider == ente.Stripe && subProductID == plan.StripeID {
|
|
return plan, country, nil
|
|
} else if s.PaymentProvider == ente.PlayStore && subProductID == plan.AndroidID {
|
|
return plan, country, nil
|
|
} else if s.PaymentProvider == ente.AppStore && subProductID == plan.IOSID {
|
|
return plan, country, nil
|
|
} else if (s.PaymentProvider == ente.BitPay || s.PaymentProvider == ente.Paypal) && subProductID == plan.ID {
|
|
return plan, country, nil
|
|
}
|
|
}
|
|
}
|
|
if s.ProductID == ente.FreePlanProductID || s.ProductID == ente.FamilyPlanProductID {
|
|
return ente.BillingPlan{Period: ente.PeriodYear}, "", nil
|
|
}
|
|
|
|
return ente.BillingPlan{}, "", stacktrace.Propagate(ente.ErrNotFound, "unable to get plan for subscription")
|
|
}
|
|
|
|
func (c *BillingController) getPlanForCountry(s ente.Subscription, countryCode string) (ente.BillingPlan, error) {
|
|
var allPlans []ente.BillingPlan
|
|
if s.PaymentProvider == ente.Stripe {
|
|
allPlans = c.getAllPlans(countryCode, s.Attributes.StripeAccountCountry)
|
|
} else {
|
|
allPlans = c.getAllPlans(countryCode, ente.DefaultStripeAccountCountry)
|
|
}
|
|
subProductID := s.ProductID
|
|
for _, plan := range allPlans {
|
|
if s.PaymentProvider == ente.Stripe && subProductID == plan.StripeID {
|
|
return plan, nil
|
|
} else if s.PaymentProvider == ente.PlayStore && subProductID == plan.AndroidID {
|
|
return plan, nil
|
|
} else if s.PaymentProvider == ente.AppStore && subProductID == plan.IOSID {
|
|
return plan, nil
|
|
} else if (s.PaymentProvider == ente.BitPay || s.PaymentProvider == ente.Paypal) && subProductID == plan.ID {
|
|
return plan, nil
|
|
}
|
|
}
|
|
if s.ProductID == ente.FreePlanProductID || s.ProductID == ente.FamilyPlanProductID {
|
|
return ente.BillingPlan{Period: ente.PeriodYear}, nil
|
|
}
|
|
|
|
// If request has a different `countryCode` because the user is traveling, and we're unable to find a plan for that country,
|
|
// fallback to the previous logic for finding a plan.
|
|
plan, _, err := c.getPlanWithCountry(s)
|
|
if err != nil {
|
|
return ente.BillingPlan{}, stacktrace.Propagate(err, "")
|
|
}
|
|
return plan, nil
|
|
}
|
|
|
|
// contains reports whether planID occurs in planIDs. A nil or empty slice
// contains nothing.
func contains(planIDs []string, planID string) bool {
	for i := 0; i < len(planIDs); i++ {
		if planIDs[i] == planID {
			return true
		}
	}
	return false
}
|