billing.go 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477
  1. package controller
  2. import (
  3. "context"
  4. "database/sql"
  5. "errors"
  6. "fmt"
  7. "github.com/ente-io/museum/pkg/controller/commonbilling"
  8. "strconv"
  9. "github.com/ente-io/museum/pkg/repo/storagebonus"
  10. "github.com/ente-io/museum/pkg/controller/discord"
  11. "github.com/ente-io/museum/pkg/controller/email"
  12. "github.com/ente-io/museum/pkg/utils/array"
  13. "github.com/ente-io/museum/pkg/utils/billing"
  14. "github.com/ente-io/museum/pkg/utils/network"
  15. "github.com/ente-io/museum/pkg/utils/time"
  16. "github.com/ente-io/stacktrace"
  17. "github.com/gin-gonic/gin"
  18. log "github.com/sirupsen/logrus"
  19. "github.com/spf13/viper"
  20. "github.com/ente-io/museum/ente"
  21. "github.com/ente-io/museum/pkg/repo"
  22. )
  23. // BillingController provides abstractions for handling billing related queries
  24. type BillingController struct {
  25. BillingPlansPerAccount ente.BillingPlansPerAccount
  26. BillingRepo *repo.BillingRepository
  27. UserRepo *repo.UserRepository
  28. UsageRepo *repo.UsageRepository
  29. StorageBonusRepo *storagebonus.Repository
  30. AppStoreController *AppStoreController
  31. PlayStoreController *PlayStoreController
  32. StripeController *StripeController
  33. DiscordController *discord.DiscordController
  34. EmailNotificationCtrl *email.EmailNotificationController
  35. CommonBillCtrl *commonbilling.Controller
  36. }
  37. // Return a new instance of BillingController
  38. func NewBillingController(
  39. plans ente.BillingPlansPerAccount,
  40. appStoreController *AppStoreController,
  41. playStoreController *PlayStoreController,
  42. stripeController *StripeController,
  43. discordController *discord.DiscordController,
  44. emailNotificationCtrl *email.EmailNotificationController,
  45. billingRepo *repo.BillingRepository,
  46. userRepo *repo.UserRepository,
  47. usageRepo *repo.UsageRepository,
  48. storageBonusRepo *storagebonus.Repository,
  49. commonBillCtrl *commonbilling.Controller,
  50. ) *BillingController {
  51. return &BillingController{
  52. BillingPlansPerAccount: plans,
  53. BillingRepo: billingRepo,
  54. UserRepo: userRepo,
  55. UsageRepo: usageRepo,
  56. AppStoreController: appStoreController,
  57. PlayStoreController: playStoreController,
  58. StripeController: stripeController,
  59. DiscordController: discordController,
  60. EmailNotificationCtrl: emailNotificationCtrl,
  61. StorageBonusRepo: storageBonusRepo,
  62. CommonBillCtrl: commonBillCtrl,
  63. }
  64. }
  65. // GetPlansV2 returns the available subscription plans for the given country and stripe account
  66. func (c *BillingController) GetPlansV2(countryCode string, stripeAccountCountry ente.StripeAccountCountry) []ente.BillingPlan {
  67. plans := c.getAllPlans(countryCode, stripeAccountCountry)
  68. result := make([]ente.BillingPlan, 0)
  69. ids := billing.GetActivePlanIDs()
  70. for _, plan := range plans {
  71. if contains(ids, plan.ID) {
  72. result = append(result, plan)
  73. }
  74. }
  75. return result
  76. }
  77. // GetStripeAccountCountry returns the stripe account country the user's existing plan is from
  78. // if he doesn't have a stripe subscription then ente.DefaultStripeAccountCountry is returned
  79. func (c *BillingController) GetStripeAccountCountry(userID int64) (ente.StripeAccountCountry, error) {
  80. subscription, err := c.BillingRepo.GetUserSubscription(userID)
  81. if err != nil {
  82. return "", stacktrace.Propagate(err, "")
  83. }
  84. if subscription.PaymentProvider != ente.Stripe {
  85. //if user doesn't have a stripe subscription, return the default stripe account country
  86. return ente.DefaultStripeAccountCountry, nil
  87. } else {
  88. return subscription.Attributes.StripeAccountCountry, nil
  89. }
  90. }
  91. // GetUserPlans returns the active plans for a user
  92. func (c *BillingController) GetUserPlans(ctx *gin.Context, userID int64) ([]ente.BillingPlan, error) {
  93. stripeAccountCountry, err := c.GetStripeAccountCountry(userID)
  94. if err != nil {
  95. return []ente.BillingPlan{}, stacktrace.Propagate(err, "Failed to get user's country stripe account")
  96. }
  97. // always return the plans based on the user's country determined by the IP
  98. return c.GetPlansV2(network.GetClientCountry(ctx), stripeAccountCountry), nil
  99. }
  100. // GetSubscription returns the current subscription for a user if any
  101. func (c *BillingController) GetSubscription(ctx *gin.Context, userID int64) (ente.Subscription, error) {
  102. s, err := c.BillingRepo.GetUserSubscription(userID)
  103. if err != nil {
  104. return ente.Subscription{}, stacktrace.Propagate(err, "")
  105. }
  106. plan, err := c.getPlanForCountry(s, network.GetClientCountry(ctx))
  107. if err != nil {
  108. return ente.Subscription{}, stacktrace.Propagate(err, "")
  109. }
  110. s.Price = plan.Price
  111. s.Period = plan.Period
  112. return s, nil
  113. }
  114. func (c *BillingController) GetRedirectURL(ctx *gin.Context) (string, error) {
  115. whitelistedRedirectURLs := viper.GetStringSlice("stripe.whitelisted-redirect-urls")
  116. redirectURL := ctx.Query("redirectURL")
  117. if len(redirectURL) > 0 && redirectURL[len(redirectURL)-1:] == "/" { // Ignore the trailing slash
  118. redirectURL = redirectURL[:len(redirectURL)-1]
  119. }
  120. for _, ar := range whitelistedRedirectURLs {
  121. if ar == redirectURL {
  122. return ar, nil
  123. }
  124. }
  125. return "", stacktrace.Propagate(ente.ErrBadRequest, fmt.Sprintf("not a whitelistedRedirectURL- %s", redirectURL))
  126. }
  127. // GetActiveSubscription returns user's active subscription or throws a error if no active subscription
  128. func (c *BillingController) GetActiveSubscription(userID int64) (ente.Subscription, error) {
  129. subscription, err := c.BillingRepo.GetUserSubscription(userID)
  130. if errors.Is(err, sql.ErrNoRows) {
  131. return subscription, ente.ErrNoActiveSubscription
  132. }
  133. if err != nil {
  134. return subscription, stacktrace.Propagate(err, "")
  135. }
  136. expiryBuffer := int64(0)
  137. if value, ok := billing.ProviderToExpiryGracePeriodMap[subscription.PaymentProvider]; ok {
  138. expiryBuffer = value
  139. }
  140. if (subscription.ExpiryTime + expiryBuffer) < time.Microseconds() {
  141. return subscription, ente.ErrNoActiveSubscription
  142. }
  143. return subscription, nil
  144. }
  145. // IsActivePayingSubscriber validates if the current user is paying customer with active subscription
  146. func (c *BillingController) IsActivePayingSubscriber(userID int64) error {
  147. subscription, err := c.GetActiveSubscription(userID)
  148. var subErr error
  149. if err != nil {
  150. subErr = stacktrace.Propagate(err, "")
  151. } else if !billing.IsActivePaidPlan(subscription) {
  152. subErr = ente.ErrSharingDisabledForFreeAccounts
  153. }
  154. if subErr != nil && (errors.Is(subErr, ente.ErrNoActiveSubscription) || errors.Is(subErr, ente.ErrSharingDisabledForFreeAccounts)) {
  155. storage, storeErr := c.StorageBonusRepo.GetPaidAddonSurplusStorage(context.Background(), userID)
  156. if storeErr != nil {
  157. return storeErr
  158. }
  159. if *storage > 0 {
  160. return nil
  161. }
  162. }
  163. return nil
  164. }
  165. // HasActiveSelfOrFamilySubscription validates if the user or user's family admin has active subscription
  166. func (c *BillingController) HasActiveSelfOrFamilySubscription(userID int64) error {
  167. var subscriptionUserID int64
  168. familyAdminID, err := c.UserRepo.GetFamilyAdminID(userID)
  169. if err != nil {
  170. return stacktrace.Propagate(err, "")
  171. }
  172. if familyAdminID != nil {
  173. subscriptionUserID = *familyAdminID
  174. } else {
  175. subscriptionUserID = userID
  176. }
  177. _, err = c.GetActiveSubscription(subscriptionUserID)
  178. if err != nil {
  179. if errors.Is(err, ente.ErrNoActiveSubscription) {
  180. storage, storeErr := c.StorageBonusRepo.GetPaidAddonSurplusStorage(context.Background(), subscriptionUserID)
  181. if storeErr != nil {
  182. return storeErr
  183. }
  184. if *storage > 0 {
  185. return nil
  186. }
  187. }
  188. return stacktrace.Propagate(err, "")
  189. }
  190. return nil
  191. }
  192. // VerifySubscription verifies and returns the verified subscription
  193. func (c *BillingController) VerifySubscription(
  194. userID int64,
  195. paymentProvider ente.PaymentProvider,
  196. productID string,
  197. verificationData string) (ente.Subscription, error) {
  198. if productID == ente.FreePlanProductID {
  199. return c.BillingRepo.GetUserSubscription(userID)
  200. }
  201. var newSubscription ente.Subscription
  202. var err error
  203. switch paymentProvider {
  204. case ente.PlayStore:
  205. newSubscription, err = c.PlayStoreController.GetVerifiedSubscription(userID, productID, verificationData)
  206. case ente.AppStore:
  207. newSubscription, err = c.AppStoreController.GetVerifiedSubscription(userID, productID, verificationData)
  208. case ente.Stripe:
  209. newSubscription, err = c.StripeController.GetVerifiedSubscription(userID, verificationData)
  210. default:
  211. err = stacktrace.Propagate(ente.ErrBadRequest, "")
  212. }
  213. if err != nil {
  214. return ente.Subscription{}, stacktrace.Propagate(err, "")
  215. }
  216. currentSubscription, err := c.BillingRepo.GetUserSubscription(userID)
  217. if err != nil {
  218. return ente.Subscription{}, stacktrace.Propagate(err, "")
  219. }
  220. newSubscriptionExpiresSooner := newSubscription.ExpiryTime < currentSubscription.ExpiryTime
  221. isUpgradingFromFreePlan := currentSubscription.ProductID == ente.FreePlanProductID
  222. hasChangedProductID := currentSubscription.ProductID != newSubscription.ProductID
  223. isOutdatedPurchase := !isUpgradingFromFreePlan && !hasChangedProductID && newSubscriptionExpiresSooner
  224. if isOutdatedPurchase {
  225. // User is reporting an outdated purchase that was already verified
  226. // no-op
  227. log.Info("Outdated purchase reported")
  228. return currentSubscription, nil
  229. }
  230. if newSubscription.Storage < currentSubscription.Storage {
  231. canDowngrade, canDowngradeErr := c.CommonBillCtrl.CanDowngradeToGivenStorage(newSubscription.Storage, userID)
  232. if canDowngradeErr != nil {
  233. return ente.Subscription{}, stacktrace.Propagate(canDowngradeErr, "")
  234. }
  235. if !canDowngrade {
  236. return ente.Subscription{}, stacktrace.Propagate(ente.ErrCannotDowngrade, "")
  237. }
  238. log.Info("Usage is good")
  239. }
  240. if newSubscription.OriginalTransactionID != "" && newSubscription.OriginalTransactionID != "none" {
  241. existingSub, existingSubErr := c.BillingRepo.GetSubscriptionForTransaction(newSubscription.OriginalTransactionID, paymentProvider)
  242. if existingSubErr != nil {
  243. if errors.Is(existingSubErr, sql.ErrNoRows) {
  244. log.Info("No subscription created yet")
  245. } else {
  246. log.Info("Something went wrong")
  247. log.WithError(existingSubErr).Error("GetSubscriptionForTransaction failed")
  248. return ente.Subscription{}, stacktrace.Propagate(existingSubErr, "")
  249. }
  250. } else {
  251. if existingSub.UserID != userID {
  252. log.WithFields(log.Fields{
  253. "original_transaction_id": existingSub.OriginalTransactionID,
  254. "existing_user": existingSub.UserID,
  255. "current_user": userID,
  256. }).Error("Subscription for given transactionID is attached with different user")
  257. log.Info("Subscription attached to different user")
  258. return ente.Subscription{}, stacktrace.Propagate(&ente.ErrSubscriptionAlreadyClaimed,
  259. fmt.Sprintf("Subscription with txn id %s already associated with user %d", newSubscription.OriginalTransactionID, existingSub.UserID))
  260. }
  261. }
  262. }
  263. err = c.BillingRepo.ReplaceSubscription(
  264. currentSubscription.ID,
  265. newSubscription,
  266. )
  267. if err != nil {
  268. return ente.Subscription{}, stacktrace.Propagate(err, "")
  269. }
  270. log.Info("Replaced subscription")
  271. newSubscription.ID = currentSubscription.ID
  272. if paymentProvider == ente.PlayStore &&
  273. newSubscription.OriginalTransactionID != currentSubscription.OriginalTransactionID {
  274. // Acknowledge to PlayStore in case of upgrades/downgrades/renewals
  275. err = c.PlayStoreController.AcknowledgeSubscription(newSubscription.ProductID, verificationData)
  276. if err != nil {
  277. log.Error("Error acknowledging subscription ", err)
  278. }
  279. }
  280. if isUpgradingFromFreePlan {
  281. go func() {
  282. amount := "unknown"
  283. plan, _, err := c.getPlanWithCountry(newSubscription)
  284. if err != nil {
  285. log.Error(err)
  286. } else {
  287. amount = plan.Price
  288. }
  289. c.DiscordController.NotifyNewSub(userID, string(paymentProvider), amount)
  290. }()
  291. go func() {
  292. c.EmailNotificationCtrl.OnAccountUpgrade(userID)
  293. }()
  294. }
  295. log.Info("Returning new subscription with ID " + strconv.FormatInt(newSubscription.ID, 10))
  296. return newSubscription, nil
  297. }
  298. func (c *BillingController) getAllPlans(countryCode string, stripeAccountCountry ente.StripeAccountCountry) []ente.BillingPlan {
  299. if array.StringInList(countryCode, billing.CountriesInEU) {
  300. countryCode = "EU"
  301. }
  302. countryWisePlans := c.BillingPlansPerAccount[stripeAccountCountry]
  303. if plans, found := countryWisePlans[countryCode]; found {
  304. return plans
  305. }
  306. // unable to find plans for given country code, return plans for default country
  307. defaultCountry := billing.GetDefaultPlanCountry()
  308. return countryWisePlans[defaultCountry]
  309. }
  310. func (c *BillingController) UpdateBillingEmail(userID int64, newEmail string) error {
  311. subscription, err := c.BillingRepo.GetUserSubscription(userID)
  312. if err != nil {
  313. return stacktrace.Propagate(err, "")
  314. }
  315. hasStripeSubscription := subscription.PaymentProvider == ente.Stripe
  316. if hasStripeSubscription {
  317. err = c.StripeController.UpdateBillingEmail(subscription, newEmail)
  318. if err != nil {
  319. return stacktrace.Propagate(err, "")
  320. }
  321. }
  322. return nil
  323. }
  324. func (c *BillingController) UpdateSubscription(r ente.UpdateSubscriptionRequest) error {
  325. subscription, err := c.BillingRepo.GetUserSubscription(r.UserID)
  326. if err != nil {
  327. return stacktrace.Propagate(err, "")
  328. }
  329. newSubscription := ente.Subscription{
  330. Storage: r.Storage,
  331. ExpiryTime: r.ExpiryTime,
  332. ProductID: r.ProductID,
  333. PaymentProvider: r.PaymentProvider,
  334. OriginalTransactionID: r.TransactionID,
  335. Attributes: r.Attributes,
  336. }
  337. err = c.BillingRepo.ReplaceSubscription(subscription.ID, newSubscription)
  338. if err != nil {
  339. return stacktrace.Propagate(err, "")
  340. }
  341. err = c.BillingRepo.LogAdminTriggeredSubscriptionUpdate(r)
  342. return stacktrace.Propagate(err, "")
  343. }
  344. func (c *BillingController) HandleAccountDeletion(ctx context.Context, userID int64, logger *log.Entry) (isCancelled bool, err error) {
  345. logger.Info("updating billing on account deletion")
  346. subscription, err := c.BillingRepo.GetUserSubscription(userID)
  347. if err != nil {
  348. return false, stacktrace.Propagate(err, "")
  349. }
  350. billingLogger := logger.WithFields(log.Fields{
  351. "customer_id": subscription.Attributes.CustomerID,
  352. "is_cancelled": subscription.Attributes.IsCancelled,
  353. "original_txn_id": subscription.OriginalTransactionID,
  354. "payment_provider": subscription.PaymentProvider,
  355. "product_id": subscription.ProductID,
  356. "stripe_account_country": subscription.Attributes.StripeAccountCountry,
  357. })
  358. billingLogger.Info("subscription fetched")
  359. // user on free plan, no action required
  360. if subscription.ProductID == ente.FreePlanProductID {
  361. billingLogger.Info("user on free plan")
  362. return true, nil
  363. }
  364. // The word "family" here is a misnomer - these are some manually created
  365. // accounts for very early adopters, and are unrelated to Family Plans.
  366. // Cancelation of these accounts will require manual intervention. Ideally,
  367. // we should never be deleting such accounts.
  368. if subscription.ProductID == ente.FamilyPlanProductID || subscription.ProductID == "" {
  369. return false, stacktrace.NewError(fmt.Sprintf("unexpected product id %s", subscription.ProductID), "")
  370. }
  371. isCancelled = subscription.Attributes.IsCancelled
  372. // delete customer data from Stripe if user is on paid plan.
  373. if subscription.PaymentProvider == ente.Stripe {
  374. err = c.StripeController.CancelSubAndDeleteCustomer(subscription, billingLogger)
  375. if err != nil {
  376. return false, stacktrace.Propagate(err, "")
  377. }
  378. // on customer deletion, subscription is automatically cancelled
  379. isCancelled = true
  380. } else if subscription.PaymentProvider == ente.AppStore || subscription.PaymentProvider == ente.PlayStore {
  381. logger.Info("Updating originalTransactionID for app/playStore provider")
  382. err := c.BillingRepo.UpdateTransactionIDOnDeletion(userID)
  383. if err != nil {
  384. return false, stacktrace.Propagate(err, "")
  385. }
  386. }
  387. return isCancelled, nil
  388. }
  389. func (c *BillingController) getPlanWithCountry(s ente.Subscription) (ente.BillingPlan, string, error) {
  390. var allPlans ente.BillingPlansPerCountry
  391. if s.PaymentProvider == ente.Stripe {
  392. allPlans = c.BillingPlansPerAccount[s.Attributes.StripeAccountCountry]
  393. } else {
  394. allPlans = c.BillingPlansPerAccount[ente.DefaultStripeAccountCountry]
  395. }
  396. subProductID := s.ProductID
  397. for country, plans := range allPlans {
  398. for _, plan := range plans {
  399. if s.PaymentProvider == ente.Stripe && subProductID == plan.StripeID {
  400. return plan, country, nil
  401. } else if s.PaymentProvider == ente.PlayStore && subProductID == plan.AndroidID {
  402. return plan, country, nil
  403. } else if s.PaymentProvider == ente.AppStore && subProductID == plan.IOSID {
  404. return plan, country, nil
  405. } else if (s.PaymentProvider == ente.BitPay || s.PaymentProvider == ente.Paypal) && subProductID == plan.ID {
  406. return plan, country, nil
  407. }
  408. }
  409. }
  410. if s.ProductID == ente.FreePlanProductID || s.ProductID == ente.FamilyPlanProductID {
  411. return ente.BillingPlan{Period: ente.PeriodYear}, "", nil
  412. }
  413. return ente.BillingPlan{}, "", stacktrace.Propagate(ente.ErrNotFound, "unable to get plan for subscription")
  414. }
  415. func (c *BillingController) getPlanForCountry(s ente.Subscription, countryCode string) (ente.BillingPlan, error) {
  416. var allPlans []ente.BillingPlan
  417. if s.PaymentProvider == ente.Stripe {
  418. allPlans = c.getAllPlans(countryCode, s.Attributes.StripeAccountCountry)
  419. } else {
  420. allPlans = c.getAllPlans(countryCode, ente.DefaultStripeAccountCountry)
  421. }
  422. subProductID := s.ProductID
  423. for _, plan := range allPlans {
  424. if s.PaymentProvider == ente.Stripe && subProductID == plan.StripeID {
  425. return plan, nil
  426. } else if s.PaymentProvider == ente.PlayStore && subProductID == plan.AndroidID {
  427. return plan, nil
  428. } else if s.PaymentProvider == ente.AppStore && subProductID == plan.IOSID {
  429. return plan, nil
  430. } else if (s.PaymentProvider == ente.BitPay || s.PaymentProvider == ente.Paypal) && subProductID == plan.ID {
  431. return plan, nil
  432. }
  433. }
  434. if s.ProductID == ente.FreePlanProductID || s.ProductID == ente.FamilyPlanProductID {
  435. return ente.BillingPlan{Period: ente.PeriodYear}, nil
  436. }
  437. // If request has a different `countryCode` because the user is traveling, and we're unable to find a plan for that country,
  438. // fallback to the previous logic for finding a plan.
  439. plan, _, err := c.getPlanWithCountry(s)
  440. if err != nil {
  441. return ente.BillingPlan{}, stacktrace.Propagate(err, "")
  442. }
  443. return plan, nil
  444. }
  // contains reports whether planID occurs in planIDs (linear scan; a nil or
  // empty slice yields false).
  func contains(planIDs []string, planID string) bool {
  	for i := range planIDs {
  		if planIDs[i] == planID {
  			return true
  		}
  	}
  	return false
  }