package controller import ( "context" "database/sql" "errors" "fmt" "github.com/ente-io/museum/pkg/controller/commonbilling" "strconv" "github.com/ente-io/museum/pkg/repo/storagebonus" "github.com/ente-io/museum/pkg/controller/discord" "github.com/ente-io/museum/pkg/controller/email" "github.com/ente-io/museum/pkg/utils/array" "github.com/ente-io/museum/pkg/utils/billing" "github.com/ente-io/museum/pkg/utils/network" "github.com/ente-io/museum/pkg/utils/time" "github.com/ente-io/stacktrace" "github.com/gin-gonic/gin" log "github.com/sirupsen/logrus" "github.com/spf13/viper" "github.com/ente-io/museum/ente" "github.com/ente-io/museum/pkg/repo" ) // BillingController provides abstractions for handling billing related queries type BillingController struct { BillingPlansPerAccount ente.BillingPlansPerAccount BillingRepo *repo.BillingRepository UserRepo *repo.UserRepository UsageRepo *repo.UsageRepository StorageBonusRepo *storagebonus.Repository AppStoreController *AppStoreController PlayStoreController *PlayStoreController StripeController *StripeController DiscordController *discord.DiscordController EmailNotificationCtrl *email.EmailNotificationController CommonBillCtrl *commonbilling.Controller } // Return a new instance of BillingController func NewBillingController( plans ente.BillingPlansPerAccount, appStoreController *AppStoreController, playStoreController *PlayStoreController, stripeController *StripeController, discordController *discord.DiscordController, emailNotificationCtrl *email.EmailNotificationController, billingRepo *repo.BillingRepository, userRepo *repo.UserRepository, usageRepo *repo.UsageRepository, storageBonusRepo *storagebonus.Repository, commonBillCtrl *commonbilling.Controller, ) *BillingController { return &BillingController{ BillingPlansPerAccount: plans, BillingRepo: billingRepo, UserRepo: userRepo, UsageRepo: usageRepo, AppStoreController: appStoreController, PlayStoreController: playStoreController, StripeController: stripeController, DiscordController: discordController, EmailNotificationCtrl: emailNotificationCtrl, StorageBonusRepo: storageBonusRepo, CommonBillCtrl: commonBillCtrl, } } // GetPlansV2 returns the available subscription plans for the given country and stripe account func (c *BillingController) GetPlansV2(countryCode string, stripeAccountCountry ente.StripeAccountCountry) []ente.BillingPlan { plans := c.getAllPlans(countryCode, stripeAccountCountry) result := make([]ente.BillingPlan, 0) ids := billing.GetActivePlanIDs() for _, plan := range plans { if contains(ids, plan.ID) { result = append(result, plan) } } return result } // GetStripeAccountCountry returns the stripe account country the user's existing plan is from // if he doesn't have a stripe subscription then ente.DefaultStripeAccountCountry is returned func (c *BillingController) GetStripeAccountCountry(userID int64) (ente.StripeAccountCountry, error) { stipeSubInfo, hasStripeSub, err := c.GetUserStripeSubscriptionInfo(userID) if err != nil { return "", stacktrace.Propagate(err, "") } if hasStripeSub { return stipeSubInfo.AccountCountry, nil } else { //if user doesn't have a stripe subscription, return the default stripe account country return ente.DefaultStripeAccountCountry, nil } } // GetUserPlans returns the active plans for a user func (c *BillingController) GetUserPlans(ctx *gin.Context, userID int64) ([]ente.BillingPlan, error) { stripeSubInfo, hasStripeSub, err := c.GetUserStripeSubscriptionInfo(userID) if err != nil { return []ente.BillingPlan{}, stacktrace.Propagate(err, "Failed to get user's subscription country and stripe account") } if hasStripeSub { return c.GetPlansV2(stripeSubInfo.PlanCountry, stripeSubInfo.AccountCountry), nil } else { // user doesn't have a stipe subscription, so return the default account plans for the country the user is from return c.GetPlansV2(network.GetClientCountry(ctx), ente.DefaultStripeAccountCountry), nil } } // GetSubscription returns the current subscription for a user if any func (c *BillingController) GetSubscription(ctx *gin.Context, userID int64) (ente.Subscription, error) { s, err := c.BillingRepo.GetUserSubscription(userID) if err != nil { return ente.Subscription{}, stacktrace.Propagate(err, "") } plan, err := c.getPlanForCountry(s, network.GetClientCountry(ctx)) if err != nil { return ente.Subscription{}, stacktrace.Propagate(err, "") } s.Price = plan.Price s.Period = plan.Period return s, nil } func (c *BillingController) GetRedirectURL(ctx *gin.Context) (string, error) { whitelistedRedirectURLs := viper.GetStringSlice("stripe.whitelisted-redirect-urls") redirectURL := ctx.Query("redirectURL") if len(redirectURL) > 0 && redirectURL[len(redirectURL)-1:] == "/" { // Ignore the trailing slash redirectURL = redirectURL[:len(redirectURL)-1] } for _, ar := range whitelistedRedirectURLs { if ar == redirectURL { return ar, nil } } return "", stacktrace.Propagate(ente.ErrBadRequest, fmt.Sprintf("not a whitelistedRedirectURL- %s", redirectURL)) } // GetActiveSubscription returns user's active subscription or throws a error if no active subscription func (c *BillingController) GetActiveSubscription(userID int64) (ente.Subscription, error) { subscription, err := c.BillingRepo.GetUserSubscription(userID) if errors.Is(err, sql.ErrNoRows) { return subscription, ente.ErrNoActiveSubscription } if err != nil { return subscription, stacktrace.Propagate(err, "") } expiryBuffer := int64(0) if value, ok := billing.ProviderToExpiryGracePeriodMap[subscription.PaymentProvider]; ok { expiryBuffer = value } if (subscription.ExpiryTime + expiryBuffer) < time.Microseconds() { return subscription, ente.ErrNoActiveSubscription } return subscription, nil } // IsActivePayingSubscriber validates if the current user is paying customer with active subscription func (c *BillingController) IsActivePayingSubscriber(userID int64) error { subscription, err := c.GetActiveSubscription(userID) var subErr error if err != nil { subErr = stacktrace.Propagate(err, "") } else if !billing.IsActivePaidPlan(subscription) { subErr = ente.ErrSharingDisabledForFreeAccounts } if subErr != nil && (errors.Is(subErr, ente.ErrNoActiveSubscription) || errors.Is(subErr, ente.ErrSharingDisabledForFreeAccounts)) { storage, storeErr := c.StorageBonusRepo.GetPaidAddonSurplusStorage(context.Background(), userID) if storeErr != nil { return storeErr } if *storage > 0 { return nil } } return nil } // HasActiveSelfOrFamilySubscription validates if the user or user's family admin has active subscription func (c *BillingController) HasActiveSelfOrFamilySubscription(userID int64) error { var subscriptionUserID int64 familyAdminID, err := c.UserRepo.GetFamilyAdminID(userID) if err != nil { return stacktrace.Propagate(err, "") } if familyAdminID != nil { subscriptionUserID = *familyAdminID } else { subscriptionUserID = userID } _, err = c.GetActiveSubscription(subscriptionUserID) if err != nil { if errors.Is(err, ente.ErrNoActiveSubscription) { storage, storeErr := c.StorageBonusRepo.GetPaidAddonSurplusStorage(context.Background(), subscriptionUserID) if storeErr != nil { return storeErr } if *storage > 0 { return nil } } return stacktrace.Propagate(err, "") } return nil } func (c *BillingController) GetUserStripeSubscriptionInfo(userID int64) (ente.StripeSubscriptionInfo, bool, error) { s, err := c.BillingRepo.GetUserSubscription(userID) if err != nil { return ente.StripeSubscriptionInfo{}, false, stacktrace.Propagate(err, "") } // skipping country code extraction for non-stripe subscriptions // as they have same product id across countries and hence can't be distinquished if s.PaymentProvider != ente.Stripe { return ente.StripeSubscriptionInfo{}, false, nil } _, countryCode, err := c.getPlanWithCountry(s) if err != nil { return ente.StripeSubscriptionInfo{}, false, stacktrace.Propagate(err, "") } return ente.StripeSubscriptionInfo{PlanCountry: countryCode, AccountCountry: s.Attributes.StripeAccountCountry}, true, nil } // VerifySubscription verifies and returns the verified subscription func (c *BillingController) VerifySubscription( userID int64, paymentProvider ente.PaymentProvider, productID string, verificationData string) (ente.Subscription, error) { if productID == ente.FreePlanProductID { return c.BillingRepo.GetUserSubscription(userID) } var newSubscription ente.Subscription var err error switch paymentProvider { case ente.PlayStore: newSubscription, err = c.PlayStoreController.GetVerifiedSubscription(userID, productID, verificationData) case ente.AppStore: newSubscription, err = c.AppStoreController.GetVerifiedSubscription(userID, productID, verificationData) case ente.Stripe: newSubscription, err = c.StripeController.GetVerifiedSubscription(userID, verificationData) default: err = stacktrace.Propagate(ente.ErrBadRequest, "") } if err != nil { return ente.Subscription{}, stacktrace.Propagate(err, "") } currentSubscription, err := c.BillingRepo.GetUserSubscription(userID) if err != nil { return ente.Subscription{}, stacktrace.Propagate(err, "") } newSubscriptionExpiresSooner := newSubscription.ExpiryTime < currentSubscription.ExpiryTime isUpgradingFromFreePlan := currentSubscription.ProductID == ente.FreePlanProductID hasChangedProductID := currentSubscription.ProductID != newSubscription.ProductID isOutdatedPurchase := !isUpgradingFromFreePlan && !hasChangedProductID && newSubscriptionExpiresSooner if isOutdatedPurchase { // User is reporting an outdated purchase that was already verified // no-op log.Info("Outdated purchase reported") return currentSubscription, nil } if newSubscription.Storage < currentSubscription.Storage { canDowngrade, canDowngradeErr := c.CommonBillCtrl.CanDowngradeToGivenStorage(newSubscription.Storage, userID) if canDowngradeErr != nil { return ente.Subscription{}, stacktrace.Propagate(canDowngradeErr, "") } if !canDowngrade { return ente.Subscription{}, stacktrace.Propagate(ente.ErrCannotDowngrade, "") } log.Info("Usage is good") } if newSubscription.OriginalTransactionID != "" && newSubscription.OriginalTransactionID != "none" { existingSub, existingSubErr := c.BillingRepo.GetSubscriptionForTransaction(newSubscription.OriginalTransactionID, paymentProvider) if existingSubErr != nil { if errors.Is(existingSubErr, sql.ErrNoRows) { log.Info("No subscription created yet") } else { log.Info("Something went wrong") log.WithError(existingSubErr).Error("GetSubscriptionForTransaction failed") return ente.Subscription{}, stacktrace.Propagate(existingSubErr, "") } } else { if existingSub.UserID != userID { log.WithFields(log.Fields{ "original_transaction_id": existingSub.OriginalTransactionID, "existing_user": existingSub.UserID, "current_user": userID, }).Error("Subscription for given transactionID is attached with different user") log.Info("Subscription attached to different user") return ente.Subscription{}, stacktrace.Propagate(&ente.ErrSubscriptionAlreadyClaimed, fmt.Sprintf("Subscription with txn id %s already associated with user %d", newSubscription.OriginalTransactionID, existingSub.UserID)) } } } err = c.BillingRepo.ReplaceSubscription( currentSubscription.ID, newSubscription, ) if err != nil { return ente.Subscription{}, stacktrace.Propagate(err, "") } log.Info("Replaced subscription") newSubscription.ID = currentSubscription.ID if paymentProvider == ente.PlayStore && newSubscription.OriginalTransactionID != currentSubscription.OriginalTransactionID { // Acknowledge to PlayStore in case of upgrades/downgrades/renewals err = c.PlayStoreController.AcknowledgeSubscription(newSubscription.ProductID, verificationData) if err != nil { log.Error("Error acknowledging subscription ", err) } } if isUpgradingFromFreePlan { go func() { amount := "unknown" plan, _, err := c.getPlanWithCountry(newSubscription) if err != nil { log.Error(err) } else { amount = plan.Price } c.DiscordController.NotifyNewSub(userID, string(paymentProvider), amount) }() go func() { c.EmailNotificationCtrl.OnAccountUpgrade(userID) }() } log.Info("Returning new subscription with ID " + strconv.FormatInt(newSubscription.ID, 10)) return newSubscription, nil } func (c *BillingController) getAllPlans(countryCode string, stripeAccountCountry ente.StripeAccountCountry) []ente.BillingPlan { if array.StringInList(countryCode, billing.CountriesInEU) { countryCode = "EU" } countryWisePlans := c.BillingPlansPerAccount[stripeAccountCountry] if plans, found := countryWisePlans[countryCode]; found { return plans } // unable to find plans for given country code, return plans for default country defaultCountry := billing.GetDefaultPlanCountry() return countryWisePlans[defaultCountry] } func (c *BillingController) UpdateBillingEmail(userID int64, newEmail string) error { subscription, err := c.BillingRepo.GetUserSubscription(userID) if err != nil { return stacktrace.Propagate(err, "") } hasStripeSubscription := subscription.PaymentProvider == ente.Stripe if hasStripeSubscription { err = c.StripeController.UpdateBillingEmail(subscription, newEmail) if err != nil { return stacktrace.Propagate(err, "") } } return nil } func (c *BillingController) UpdateSubscription(r ente.UpdateSubscriptionRequest) error { subscription, err := c.BillingRepo.GetUserSubscription(r.UserID) if err != nil { return stacktrace.Propagate(err, "") } newSubscription := ente.Subscription{ Storage: r.Storage, ExpiryTime: r.ExpiryTime, ProductID: r.ProductID, PaymentProvider: r.PaymentProvider, OriginalTransactionID: r.TransactionID, Attributes: r.Attributes, } err = c.BillingRepo.ReplaceSubscription(subscription.ID, newSubscription) if err != nil { return stacktrace.Propagate(err, "") } err = c.BillingRepo.LogAdminTriggeredSubscriptionUpdate(r) return stacktrace.Propagate(err, "") } func (c *BillingController) HandleAccountDeletion(ctx context.Context, userID int64, logger *log.Entry) (isCancelled bool, err error) { logger.Info("updating billing on account deletion") subscription, err := c.BillingRepo.GetUserSubscription(userID) if err != nil { return false, stacktrace.Propagate(err, "") } billingLogger := logger.WithFields(log.Fields{ "customer_id": subscription.Attributes.CustomerID, "is_cancelled": subscription.Attributes.IsCancelled, "original_txn_id": subscription.OriginalTransactionID, "payment_provider": subscription.PaymentProvider, "product_id": subscription.ProductID, "stripe_account_country": subscription.Attributes.StripeAccountCountry, }) billingLogger.Info("subscription fetched") // user on free plan, no action required if subscription.ProductID == ente.FreePlanProductID { billingLogger.Info("user on free plan") return true, nil } // The word "family" here is a misnomer - these are some manually created // accounts for very early adopters, and are unrelated to Family Plans. // Cancelation of these accounts will require manual intervention. Ideally, // we should never be deleting such accounts. if subscription.ProductID == ente.FamilyPlanProductID || subscription.ProductID == "" { return false, stacktrace.NewError(fmt.Sprintf("unexpected product id %s", subscription.ProductID), "") } isCancelled = subscription.Attributes.IsCancelled // delete customer data from Stripe if user is on paid plan. if subscription.PaymentProvider == ente.Stripe { err = c.StripeController.CancelSubAndDeleteCustomer(subscription, billingLogger) if err != nil { return false, stacktrace.Propagate(err, "") } // on customer deletion, subscription is automatically cancelled isCancelled = true } else if subscription.PaymentProvider == ente.AppStore || subscription.PaymentProvider == ente.PlayStore { logger.Info("Updating originalTransactionID for app/playStore provider") err := c.BillingRepo.UpdateTransactionIDOnDeletion(userID) if err != nil { return false, stacktrace.Propagate(err, "") } } return isCancelled, nil } func (c *BillingController) getPlanWithCountry(s ente.Subscription) (ente.BillingPlan, string, error) { var allPlans ente.BillingPlansPerCountry if s.PaymentProvider == ente.Stripe { allPlans = c.BillingPlansPerAccount[s.Attributes.StripeAccountCountry] } else { allPlans = c.BillingPlansPerAccount[ente.DefaultStripeAccountCountry] } subProductID := s.ProductID for country, plans := range allPlans { for _, plan := range plans { if s.PaymentProvider == ente.Stripe && subProductID == plan.StripeID { return plan, country, nil } else if s.PaymentProvider == ente.PlayStore && subProductID == plan.AndroidID { return plan, country, nil } else if s.PaymentProvider == ente.AppStore && subProductID == plan.IOSID { return plan, country, nil } else if (s.PaymentProvider == ente.BitPay || s.PaymentProvider == ente.Paypal) && subProductID == plan.ID { return plan, country, nil } } } if s.ProductID == ente.FreePlanProductID || s.ProductID == ente.FamilyPlanProductID { return ente.BillingPlan{Period: ente.PeriodYear}, "", nil } return ente.BillingPlan{}, "", stacktrace.Propagate(ente.ErrNotFound, "unable to get plan for subscription") } func (c *BillingController) getPlanForCountry(s ente.Subscription, countryCode string) (ente.BillingPlan, error) { var allPlans []ente.BillingPlan if s.PaymentProvider == ente.Stripe { allPlans = c.getAllPlans(countryCode, s.Attributes.StripeAccountCountry) } else { allPlans = c.getAllPlans(countryCode, ente.DefaultStripeAccountCountry) } subProductID := s.ProductID for _, plan := range allPlans { if s.PaymentProvider == ente.Stripe && subProductID == plan.StripeID { return plan, nil } else if s.PaymentProvider == ente.PlayStore && subProductID == plan.AndroidID { return plan, nil } else if s.PaymentProvider == ente.AppStore && subProductID == plan.IOSID { return plan, nil } else if (s.PaymentProvider == ente.BitPay || s.PaymentProvider == ente.Paypal) && subProductID == plan.ID { return plan, nil } } if s.ProductID == ente.FreePlanProductID || s.ProductID == ente.FamilyPlanProductID { return ente.BillingPlan{Period: ente.PeriodYear}, nil } // If request has a different `countryCode` because the user is traveling, and we're unable to find a plan for that country, // fallback to the previous logic for finding a plan. plan, _, err := c.getPlanWithCountry(s) if err != nil { return ente.BillingPlan{}, stacktrace.Propagate(err, "") } return plan, nil } func contains(planIDs []string, planID string) bool { for _, id := range planIDs { if id == planID { return true } } return false }