stripe.go 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714
  1. package controller
  2. import (
  3. "database/sql"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "net/http"
  8. "strconv"
  9. "time"
  10. "github.com/ente-io/museum/pkg/controller/commonbilling"
  11. "github.com/ente-io/museum/pkg/controller/discord"
  12. "github.com/ente-io/museum/pkg/controller/offer"
  13. "github.com/ente-io/museum/pkg/repo/storagebonus"
  14. "github.com/ente-io/museum/ente"
  15. emailCtrl "github.com/ente-io/museum/pkg/controller/email"
  16. "github.com/ente-io/museum/pkg/repo"
  17. "github.com/ente-io/museum/pkg/utils/billing"
  18. "github.com/ente-io/museum/pkg/utils/email"
  19. "github.com/ente-io/stacktrace"
  20. log "github.com/sirupsen/logrus"
  21. "github.com/spf13/viper"
  22. "github.com/stripe/stripe-go/v72"
  23. "github.com/stripe/stripe-go/v72/client"
  24. "github.com/stripe/stripe-go/v72/webhook"
  25. "golang.org/x/text/currency"
  26. )
  27. // StripeController provides abstractions for handling billing on Stripe
  28. type StripeController struct {
  29. StripeClients ente.StripeClientPerAccount
  30. BillingPlansPerAccount ente.BillingPlansPerAccount
  31. BillingRepo *repo.BillingRepository
  32. FileRepo *repo.FileRepository
  33. UserRepo *repo.UserRepository
  34. StorageBonusRepo *storagebonus.Repository
  35. DiscordController *discord.DiscordController
  36. EmailNotificationCtrl *emailCtrl.EmailNotificationController
  37. OfferController *offer.OfferController
  38. CommonBillCtrl *commonbilling.Controller
  39. }
  40. const BufferPeriodOnPaymentFailureInDays = 7
  41. // Return a new instance of StripeController
  42. func NewStripeController(plans ente.BillingPlansPerAccount, stripeClients ente.StripeClientPerAccount, billingRepo *repo.BillingRepository, fileRepo *repo.FileRepository, userRepo *repo.UserRepository, storageBonusRepo *storagebonus.Repository, discordController *discord.DiscordController, emailNotificationController *emailCtrl.EmailNotificationController, offerController *offer.OfferController, commonBillCtrl *commonbilling.Controller) *StripeController {
  43. return &StripeController{
  44. StripeClients: stripeClients,
  45. BillingRepo: billingRepo,
  46. FileRepo: fileRepo,
  47. UserRepo: userRepo,
  48. BillingPlansPerAccount: plans,
  49. StorageBonusRepo: storageBonusRepo,
  50. DiscordController: discordController,
  51. EmailNotificationCtrl: emailNotificationController,
  52. OfferController: offerController,
  53. CommonBillCtrl: commonBillCtrl,
  54. }
  55. }
  56. // GetCheckoutSession handles the creation of stripe checkout session for subscription purchase
  57. func (c *StripeController) GetCheckoutSession(productID string, userID int64, redirectRootURL string) (string, error) {
  58. if productID == "" {
  59. return "", stacktrace.Propagate(ente.ErrBadRequest, "")
  60. }
  61. subscription, err := c.BillingRepo.GetUserSubscription(userID)
  62. if err != nil {
  63. // error sql.ErrNoRows not possible as user must at least have a free subscription
  64. return "", stacktrace.Propagate(err, "")
  65. }
  66. hasActivePaidSubscription := billing.IsActivePaidPlan(subscription)
  67. hasStripeSubscription := subscription.PaymentProvider == ente.Stripe
  68. if hasActivePaidSubscription {
  69. if hasStripeSubscription {
  70. return "", stacktrace.Propagate(ente.ErrBadRequest, "")
  71. } else if !subscription.Attributes.IsCancelled {
  72. return "", stacktrace.Propagate(ente.ErrBadRequest, "")
  73. }
  74. }
  75. if hasStripeSubscription {
  76. client := c.StripeClients[subscription.Attributes.StripeAccountCountry]
  77. stripeSubscription, err := client.Subscriptions.Get(subscription.OriginalTransactionID, nil)
  78. if err != nil {
  79. return "", stacktrace.Propagate(err, "")
  80. }
  81. if stripeSubscription.Status != stripe.SubscriptionStatusCanceled {
  82. return "", stacktrace.Propagate(ente.ErrBadRequest, "")
  83. }
  84. }
  85. stripeSuccessURL := redirectRootURL + viper.GetString("stripe.path.success")
  86. stripeCancelURL := redirectRootURL + viper.GetString("stripe.path.cancel")
  87. allowPromotionCodes := true
  88. params := &stripe.CheckoutSessionParams{
  89. ClientReferenceID: stripe.String(strconv.FormatInt(userID, 10)),
  90. SuccessURL: stripe.String(stripeSuccessURL),
  91. CancelURL: stripe.String(stripeCancelURL),
  92. Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)),
  93. LineItems: []*stripe.CheckoutSessionLineItemParams{
  94. {
  95. Price: stripe.String(productID),
  96. Quantity: stripe.Int64(1),
  97. },
  98. },
  99. AllowPromotionCodes: &allowPromotionCodes,
  100. }
  101. var stripeClient *client.API
  102. if subscription.PaymentProvider == ente.Stripe {
  103. stripeClient = c.StripeClients[subscription.Attributes.StripeAccountCountry]
  104. // attach the subscription to existing customerID
  105. params.Customer = stripe.String(subscription.Attributes.CustomerID)
  106. } else {
  107. stripeClient = c.StripeClients[ente.DefaultStripeAccountCountry]
  108. user, err := c.UserRepo.Get(userID)
  109. if err != nil {
  110. return "", stacktrace.Propagate(err, "")
  111. }
  112. // attach user's emailID to the checkout session and subsequent subscription bought
  113. params.CustomerEmail = stripe.String(user.Email)
  114. }
  115. s, err := stripeClient.CheckoutSessions.New(params)
  116. if err != nil {
  117. return "", stacktrace.Propagate(err, "")
  118. }
  119. return s.ID, nil
  120. }
  121. // GetVerifiedSubscription verifies and returns the verified subscription
  122. func (c *StripeController) GetVerifiedSubscription(userID int64, sessionID string) (ente.Subscription, error) {
  123. var stripeSubscription stripe.Subscription
  124. var err error
  125. if sessionID != "" {
  126. log.Info("Received session ID: " + sessionID)
  127. // Get verified subscription request was received from success redirect page
  128. stripeSubscription, err = c.getStripeSubscriptionFromSession(userID, sessionID)
  129. } else {
  130. log.Info("Did not receive a session ID")
  131. // Get verified subscription request for a subscription update
  132. stripeSubscription, err = c.getUserStripeSubscription(userID)
  133. }
  134. if err != nil {
  135. return ente.Subscription{}, stacktrace.Propagate(err, "")
  136. }
  137. log.Info("Received stripe subscription with ID: " + stripeSubscription.ID)
  138. subscription, err := c.getEnteSubscriptionFromStripeSubscription(userID, stripeSubscription)
  139. if err != nil {
  140. return ente.Subscription{}, stacktrace.Propagate(err, "")
  141. }
  142. log.Info("Returning ente subscription with ID: " + strconv.FormatInt(subscription.ID, 10))
  143. return subscription, nil
  144. }
  145. func (c *StripeController) HandleUSNotification(payload []byte, header string) error {
  146. event, err := webhook.ConstructEvent(payload, header, viper.GetString("stripe.us.webhook-secret"))
  147. if err != nil {
  148. return stacktrace.Propagate(err, "")
  149. }
  150. return c.handleWebhookEvent(event, ente.StripeUS)
  151. }
  152. func (c *StripeController) HandleINNotification(payload []byte, header string) error {
  153. event, err := webhook.ConstructEvent(payload, header, viper.GetString("stripe.in.webhook-secret"))
  154. if err != nil {
  155. return stacktrace.Propagate(err, "")
  156. }
  157. return c.handleWebhookEvent(event, ente.StripeIN)
  158. }
  159. func (c *StripeController) handleWebhookEvent(event stripe.Event, country ente.StripeAccountCountry) error {
  160. // The event body would already have been logged by the upper layers by the
  161. // time we get here, so we can only handle the events that we care about. In
  162. // case we receive an unexpected event, we do log an error though.
  163. handler := c.findHandlerForEvent(event)
  164. if handler == nil {
  165. log.Error("Received an unexpected webhook from stripe:", event.Type)
  166. return nil
  167. }
  168. eventLog, err := handler(event, country)
  169. if err != nil {
  170. return stacktrace.Propagate(err, "")
  171. }
  172. if eventLog.UserID == 0 {
  173. // Do not try to log if we do not have an associated user. This can
  174. // happen, e.g. with out of order webhooks.
  175. // Or in case of offer application, where events are logged by the Storage Bonus Repo
  176. //
  177. // See: Ignore webhooks received before user has been created
  178. return nil
  179. }
  180. err = c.BillingRepo.LogStripePush(eventLog)
  181. return stacktrace.Propagate(err, "")
  182. }
  183. func (c *StripeController) findHandlerForEvent(event stripe.Event) func(event stripe.Event, country ente.StripeAccountCountry) (ente.StripeEventLog, error) {
  184. switch event.Type {
  185. case "checkout.session.completed":
  186. return c.handleCheckoutSessionCompleted
  187. case "customer.subscription.updated":
  188. return c.handleCustomerSubscriptionUpdated
  189. case "invoice.paid":
  190. return c.handleInvoicePaid
  191. case "payment_intent.payment_failed":
  192. return c.handlePaymentIntentFailed
  193. default:
  194. return nil
  195. }
  196. }
  197. // Payment is successful and the subscription is created.
  198. // You should provision the subscription.
  199. func (c *StripeController) handleCheckoutSessionCompleted(event stripe.Event, country ente.StripeAccountCountry) (ente.StripeEventLog, error) {
  200. var session stripe.CheckoutSession
  201. json.Unmarshal(event.Data.Raw, &session)
  202. if session.ClientReferenceID != "" { // via payments.ente.io, where we inserted the userID
  203. userID, _ := strconv.ParseInt(session.ClientReferenceID, 10, 64)
  204. newSubscription, err := c.GetVerifiedSubscription(userID, session.ID)
  205. if err != nil {
  206. return ente.StripeEventLog{}, stacktrace.Propagate(err, "")
  207. }
  208. stripeSubscription, err := c.getStripeSubscriptionFromSession(userID, session.ID)
  209. if err != nil {
  210. return ente.StripeEventLog{}, stacktrace.Propagate(err, "")
  211. }
  212. currentSubscription, err := c.BillingRepo.GetUserSubscription(userID)
  213. if err != nil {
  214. return ente.StripeEventLog{}, stacktrace.Propagate(err, "")
  215. }
  216. if currentSubscription.ExpiryTime >= newSubscription.ExpiryTime &&
  217. currentSubscription.ProductID != ente.FreePlanProductID {
  218. log.Warn("Webhook is reporting an outdated purchase that was already verified stripeSubscription:", stripeSubscription.ID)
  219. return ente.StripeEventLog{UserID: userID, StripeSubscription: stripeSubscription, Event: event}, nil
  220. }
  221. err = c.BillingRepo.ReplaceSubscription(
  222. currentSubscription.ID,
  223. newSubscription,
  224. )
  225. isUpgradingFromFreePlan := currentSubscription.ProductID == ente.FreePlanProductID
  226. if isUpgradingFromFreePlan {
  227. go func() {
  228. cur := currency.MustParseISO(string(session.Currency))
  229. amount := fmt.Sprintf("%v%v", currency.Symbol(cur), float64(session.AmountTotal)/float64(100))
  230. c.DiscordController.NotifyNewSub(userID, "stripe", amount)
  231. }()
  232. go func() {
  233. c.EmailNotificationCtrl.OnAccountUpgrade(userID)
  234. }()
  235. }
  236. if err != nil {
  237. return ente.StripeEventLog{}, stacktrace.Propagate(err, "")
  238. }
  239. return ente.StripeEventLog{UserID: userID, StripeSubscription: stripeSubscription, Event: event}, nil
  240. } else {
  241. priceID, err := c.getPriceIDFromSession(session.ID)
  242. if err != nil {
  243. return ente.StripeEventLog{}, stacktrace.Propagate(err, "")
  244. }
  245. email := session.CustomerDetails.Email
  246. err = c.OfferController.ApplyOffer(email, priceID)
  247. if err != nil {
  248. return ente.StripeEventLog{}, stacktrace.Propagate(err, "")
  249. }
  250. }
  251. return ente.StripeEventLog{}, nil
  252. }
  253. // Stripe fires this when a subscription starts or changes. For example,
  254. // renewing a subscription, adding a coupon, applying a discount, adding an
  255. // invoice item, and changing plans all trigger this event. In our case, we use
  256. // this only to track plan changes or subscriptions going past due. The rest
  257. // (subscription creations, deletions, renewals and failures) are tracked by
  258. // individual events.
  259. func (c *StripeController) handleCustomerSubscriptionUpdated(event stripe.Event, country ente.StripeAccountCountry) (ente.StripeEventLog, error) {
  260. var stripeSubscription stripe.Subscription
  261. json.Unmarshal(event.Data.Raw, &stripeSubscription)
  262. currentSubscription, err := c.BillingRepo.GetSubscriptionForTransaction(stripeSubscription.ID, ente.Stripe)
  263. if err != nil {
  264. if errors.Is(err, sql.ErrNoRows) {
  265. // See: Ignore webhooks received before user has been created
  266. log.Warn("Webhook is reporting an event for un-verified subscription stripeSubscriptionID:", stripeSubscription.ID)
  267. return ente.StripeEventLog{}, nil
  268. }
  269. return ente.StripeEventLog{}, stacktrace.Propagate(err, "")
  270. }
  271. userID := currentSubscription.UserID
  272. if stripeSubscription.Status == stripe.SubscriptionStatusPastDue {
  273. user, err := c.UserRepo.Get(userID)
  274. if err != nil {
  275. return ente.StripeEventLog{}, stacktrace.Propagate(err, "")
  276. }
  277. err = email.SendTemplatedEmail([]string{user.Email}, "ente", "support@ente.io",
  278. ente.AccountOnHoldEmailSubject, ente.OnHoldTemplate, map[string]interface{}{
  279. "PaymentProvider": "Stripe",
  280. }, nil)
  281. if err != nil {
  282. return ente.StripeEventLog{}, stacktrace.Propagate(err, "")
  283. }
  284. }
  285. newSubscription, err := c.getEnteSubscriptionFromStripeSubscription(userID, stripeSubscription)
  286. if err != nil {
  287. return ente.StripeEventLog{}, stacktrace.Propagate(err, "")
  288. }
  289. // If the customer has changed the plan, we update state in the database. If
  290. // the plan has not changed, we will ignore this webhook and rely on other
  291. // events to update the state
  292. if currentSubscription.ProductID != newSubscription.ProductID {
  293. c.BillingRepo.ReplaceSubscription(currentSubscription.ID, newSubscription)
  294. }
  295. return ente.StripeEventLog{UserID: userID, StripeSubscription: stripeSubscription, Event: event}, nil
  296. }
  297. // Continue to provision the subscription as payments continue to be made.
  298. func (c *StripeController) handleInvoicePaid(event stripe.Event, country ente.StripeAccountCountry) (ente.StripeEventLog, error) {
  299. var invoice stripe.Invoice
  300. json.Unmarshal(event.Data.Raw, &invoice)
  301. stripeSubscriptionID := invoice.Subscription.ID
  302. currentSubscription, err := c.BillingRepo.GetSubscriptionForTransaction(stripeSubscriptionID, ente.Stripe)
  303. if err != nil {
  304. if errors.Is(err, sql.ErrNoRows) {
  305. // See: Ignore webhooks received before user has been created
  306. log.Warn("Webhook is reporting an event for un-verified subscription stripeSubscriptionID:", stripeSubscriptionID)
  307. return ente.StripeEventLog{}, nil
  308. }
  309. return ente.StripeEventLog{}, stacktrace.Propagate(err, "")
  310. }
  311. userID := currentSubscription.UserID
  312. client := c.StripeClients[currentSubscription.Attributes.StripeAccountCountry]
  313. stripeSubscription, err := client.Subscriptions.Get(stripeSubscriptionID, nil)
  314. if err != nil {
  315. return ente.StripeEventLog{}, stacktrace.Propagate(err, "")
  316. }
  317. newExpiryTime := stripeSubscription.CurrentPeriodEnd * 1000 * 1000
  318. if currentSubscription.ExpiryTime == newExpiryTime {
  319. //outdated invoice
  320. log.Warn("Webhook is reporting an outdated purchase that was already verified stripeSubscriptionID:", stripeSubscription.ID)
  321. return ente.StripeEventLog{UserID: userID, StripeSubscription: *stripeSubscription, Event: event}, nil
  322. }
  323. err = c.BillingRepo.UpdateSubscriptionExpiryTime(
  324. currentSubscription.ID, newExpiryTime)
  325. if err != nil {
  326. return ente.StripeEventLog{}, stacktrace.Propagate(err, "")
  327. }
  328. return ente.StripeEventLog{UserID: userID, StripeSubscription: *stripeSubscription, Event: event}, nil
  329. }
  330. // Event used to ONLY handle failures to SEPA payments, since we set
  331. // SubscriptionPaymentBehaviorAllowIncomplete only for SEPA. Other payment modes
  332. // will fail and will be handled synchronously
  333. func (c *StripeController) handlePaymentIntentFailed(event stripe.Event, country ente.StripeAccountCountry) (ente.StripeEventLog, error) {
  334. var paymentIntent stripe.PaymentIntent
  335. json.Unmarshal(event.Data.Raw, &paymentIntent)
  336. isSEPA := paymentIntent.LastPaymentError.PaymentMethod.Type == stripe.PaymentMethodTypeSepaDebit
  337. if !isSEPA {
  338. // Ignore events for other payment methods, since they will be handled
  339. // synchronously
  340. log.Info("Ignoring payment intent failed event for non-SEPA payment method")
  341. return ente.StripeEventLog{}, nil
  342. }
  343. client := c.StripeClients[country]
  344. invoiceID := paymentIntent.Invoice.ID
  345. invoice, err := client.Invoices.Get(invoiceID, nil)
  346. if err != nil {
  347. return ente.StripeEventLog{}, stacktrace.Propagate(err, "")
  348. }
  349. stripeSubscriptionID := invoice.Subscription.ID
  350. currentSubscription, err := c.BillingRepo.GetSubscriptionForTransaction(stripeSubscriptionID, ente.Stripe)
  351. if err != nil {
  352. if errors.Is(err, sql.ErrNoRows) {
  353. // See: Ignore webhooks received before user has been created
  354. log.Warn("Webhook is reporting an event for un-verified subscription stripeSubscriptionID:", stripeSubscriptionID)
  355. }
  356. return ente.StripeEventLog{}, stacktrace.Propagate(err, "")
  357. }
  358. userID := currentSubscription.UserID
  359. stripeSubscription, err := client.Subscriptions.Get(stripeSubscriptionID, nil)
  360. if err != nil {
  361. return ente.StripeEventLog{}, stacktrace.Propagate(err, "")
  362. }
  363. productID := stripeSubscription.Items.Data[0].Price.ID
  364. // If the current subscription is not the same as the one in the webhook,
  365. // then ignore
  366. fmt.Printf("productID: %s, currentSubscription.ProductID: %s\n", productID, currentSubscription.ProductID)
  367. if currentSubscription.ProductID != productID {
  368. // no-op
  369. log.Warn("Webhook is reporting un-verified subscription update", stripeSubscription.ID, "invoiceID:", invoiceID)
  370. return ente.StripeEventLog{UserID: userID, StripeSubscription: *stripeSubscription, Event: event}, nil
  371. }
  372. // If the current subscription is the same as the one in the webhook, then
  373. // we need to expire the subscription, and send an email to the user.
  374. newExpiryTime := time.Now().UnixMicro()
  375. err = c.BillingRepo.UpdateSubscriptionExpiryTime(
  376. currentSubscription.ID, newExpiryTime)
  377. if err != nil {
  378. return ente.StripeEventLog{}, stacktrace.Propagate(err, "")
  379. }
  380. // Send an email to the user
  381. user, err := c.UserRepo.Get(userID)
  382. if err != nil {
  383. return ente.StripeEventLog{}, stacktrace.Propagate(err, "")
  384. }
  385. // TODO: Inform customer that payment_failed.html with invoice.HostedInvoiceURL
  386. err = email.SendTemplatedEmail([]string{user.Email}, "ente", "support@ente.io",
  387. ente.AccountOnHoldEmailSubject, ente.OnHoldTemplate, map[string]interface{}{
  388. "PaymentProvider": "Stripe",
  389. "InvoiceURL": invoice.HostedInvoiceURL,
  390. }, nil)
  391. if err != nil {
  392. return ente.StripeEventLog{}, stacktrace.Propagate(err, "")
  393. }
  394. return ente.StripeEventLog{UserID: userID, StripeSubscription: *stripeSubscription, Event: event}, nil
  395. }
  396. func (c *StripeController) UpdateSubscription(stripeID string, userID int64) (ente.SubscriptionUpdateResponse, error) {
  397. subscription, err := c.BillingRepo.GetUserSubscription(userID)
  398. if err != nil {
  399. return ente.SubscriptionUpdateResponse{}, stacktrace.Propagate(err, "")
  400. }
  401. newPlan, newStripeAccountCountry, err := c.getPlanAndAccount(stripeID)
  402. if err != nil {
  403. return ente.SubscriptionUpdateResponse{}, stacktrace.Propagate(err, "")
  404. }
  405. if subscription.PaymentProvider != ente.Stripe || subscription.ProductID == stripeID || subscription.Attributes.StripeAccountCountry != newStripeAccountCountry {
  406. return ente.SubscriptionUpdateResponse{}, stacktrace.Propagate(ente.ErrBadRequest, "")
  407. }
  408. if newPlan.Storage < subscription.Storage { // Downgrade
  409. canDowngrade, canDowngradeErr := c.CommonBillCtrl.CanDowngradeToGivenStorage(newPlan.Storage, userID)
  410. if canDowngradeErr != nil {
  411. return ente.SubscriptionUpdateResponse{}, stacktrace.Propagate(canDowngradeErr, "")
  412. }
  413. if !canDowngrade {
  414. return ente.SubscriptionUpdateResponse{}, stacktrace.Propagate(ente.ErrCannotDowngrade, "")
  415. }
  416. log.Info("Usage is good")
  417. }
  418. client := c.StripeClients[subscription.Attributes.StripeAccountCountry]
  419. params := stripe.SubscriptionParams{}
  420. params.AddExpand("default_payment_method")
  421. stripeSubscription, err := client.Subscriptions.Get(subscription.OriginalTransactionID, &params)
  422. if err != nil {
  423. return ente.SubscriptionUpdateResponse{}, stacktrace.Propagate(err, "")
  424. }
  425. isSEPA := false
  426. if stripeSubscription.DefaultPaymentMethod != nil {
  427. isSEPA = stripeSubscription.DefaultPaymentMethod.Type == stripe.PaymentMethodTypeSepaDebit
  428. } else {
  429. log.Info("No default payment method found")
  430. }
  431. var paymentBehavior stripe.SubscriptionPaymentBehavior
  432. if isSEPA {
  433. paymentBehavior = stripe.SubscriptionPaymentBehaviorAllowIncomplete
  434. } else {
  435. paymentBehavior = stripe.SubscriptionPaymentBehaviorPendingIfIncomplete
  436. }
  437. params = stripe.SubscriptionParams{
  438. ProrationBehavior: stripe.String(string(stripe.SubscriptionProrationBehaviorAlwaysInvoice)),
  439. Items: []*stripe.SubscriptionItemsParams{
  440. {
  441. ID: stripe.String(stripeSubscription.Items.Data[0].ID),
  442. Price: stripe.String(stripeID),
  443. },
  444. },
  445. PaymentBehavior: stripe.String(string(paymentBehavior)),
  446. }
  447. params.AddExpand("latest_invoice.payment_intent")
  448. newStripeSubscription, err := client.Subscriptions.Update(subscription.OriginalTransactionID, &params)
  449. if err != nil {
  450. stripeError := err.(*stripe.Error)
  451. switch stripeError.Type {
  452. case stripe.ErrorTypeCard:
  453. return ente.SubscriptionUpdateResponse{Status: "requires_payment_method"}, nil
  454. default:
  455. return ente.SubscriptionUpdateResponse{}, stacktrace.Propagate(err, "")
  456. }
  457. }
  458. if isSEPA {
  459. if newStripeSubscription.Status == stripe.SubscriptionStatusPastDue {
  460. if newStripeSubscription.LatestInvoice.PaymentIntent.Status == stripe.PaymentIntentStatusRequiresAction {
  461. return ente.SubscriptionUpdateResponse{Status: "requires_action", ClientSecret: newStripeSubscription.LatestInvoice.PaymentIntent.ClientSecret}, nil
  462. } else if newStripeSubscription.LatestInvoice.PaymentIntent.Status == stripe.PaymentIntentStatusRequiresPaymentMethod {
  463. return ente.SubscriptionUpdateResponse{Status: "requires_payment_method"}, nil
  464. } else if newStripeSubscription.LatestInvoice.PaymentIntent.Status == stripe.PaymentIntentStatusProcessing {
  465. return ente.SubscriptionUpdateResponse{Status: "success"}, nil
  466. }
  467. return ente.SubscriptionUpdateResponse{}, stacktrace.Propagate(ente.ErrBadRequest, "")
  468. }
  469. } else {
  470. if newStripeSubscription.PendingUpdate != nil {
  471. switch newStripeSubscription.LatestInvoice.PaymentIntent.Status {
  472. case stripe.PaymentIntentStatusRequiresAction:
  473. return ente.SubscriptionUpdateResponse{Status: "requires_action", ClientSecret: newStripeSubscription.LatestInvoice.PaymentIntent.ClientSecret}, nil
  474. case stripe.PaymentIntentStatusRequiresPaymentMethod:
  475. inv := newStripeSubscription.LatestInvoice
  476. client.Invoices.VoidInvoice(inv.ID, nil)
  477. return ente.SubscriptionUpdateResponse{Status: "requires_payment_method"}, nil
  478. }
  479. return ente.SubscriptionUpdateResponse{}, stacktrace.Propagate(ente.ErrBadRequest, "")
  480. }
  481. }
  482. return ente.SubscriptionUpdateResponse{Status: "success"}, nil
  483. }
  484. func (c *StripeController) UpdateSubscriptionCancellationStatus(userID int64, status bool) (ente.Subscription, error) {
  485. subscription, err := c.BillingRepo.GetUserSubscription(userID)
  486. if err != nil {
  487. // error sql.ErrNoRows not possible as user must at least have a free subscription
  488. return ente.Subscription{}, stacktrace.Propagate(err, "")
  489. }
  490. if subscription.PaymentProvider != ente.Stripe {
  491. return ente.Subscription{}, stacktrace.Propagate(ente.ErrBadRequest, "")
  492. }
  493. if subscription.Attributes.IsCancelled == status {
  494. // no-op
  495. return subscription, nil
  496. }
  497. client := c.StripeClients[subscription.Attributes.StripeAccountCountry]
  498. params := &stripe.SubscriptionParams{
  499. CancelAtPeriodEnd: stripe.Bool(status),
  500. }
  501. _, err = client.Subscriptions.Update(subscription.OriginalTransactionID, params)
  502. if err != nil {
  503. return ente.Subscription{}, stacktrace.Propagate(err, "")
  504. }
  505. err = c.BillingRepo.UpdateSubscriptionCancellationStatus(userID, status)
  506. if err != nil {
  507. return ente.Subscription{}, stacktrace.Propagate(err, "")
  508. }
  509. subscription.Attributes.IsCancelled = status
  510. return subscription, nil
  511. }
  512. func (c *StripeController) GetStripeCustomerPortal(userID int64, redirectRootURL string) (string, error) {
  513. subscription, err := c.BillingRepo.GetUserSubscription(userID)
  514. if err != nil {
  515. return "", stacktrace.Propagate(err, "")
  516. }
  517. if subscription.PaymentProvider != ente.Stripe {
  518. return "", stacktrace.Propagate(ente.ErrBadRequest, "")
  519. }
  520. client := c.StripeClients[subscription.Attributes.StripeAccountCountry]
  521. params := &stripe.BillingPortalSessionParams{
  522. Customer: stripe.String(subscription.Attributes.CustomerID),
  523. ReturnURL: stripe.String(redirectRootURL),
  524. }
  525. ps, err := client.BillingPortalSessions.New(params)
  526. if err != nil {
  527. return "", stacktrace.Propagate(err, "")
  528. }
  529. return ps.URL, nil
  530. }
  531. func (c *StripeController) getStripeSubscriptionFromSession(userID int64, checkoutSessionID string) (stripe.Subscription, error) {
  532. subscription, err := c.BillingRepo.GetUserSubscription(userID)
  533. if err != nil {
  534. return stripe.Subscription{}, stacktrace.Propagate(err, "")
  535. }
  536. var stripeClient *client.API
  537. if subscription.PaymentProvider == ente.Stripe {
  538. stripeClient = c.StripeClients[subscription.Attributes.StripeAccountCountry]
  539. } else {
  540. stripeClient = c.StripeClients[ente.DefaultStripeAccountCountry]
  541. }
  542. params := &stripe.CheckoutSessionParams{}
  543. params.AddExpand("subscription")
  544. checkoutSession, err := stripeClient.CheckoutSessions.Get(checkoutSessionID, params)
  545. if err != nil {
  546. return stripe.Subscription{}, stacktrace.Propagate(err, "")
  547. }
  548. if (*checkoutSession.Subscription).Status != stripe.SubscriptionStatusActive {
  549. return stripe.Subscription{}, stacktrace.Propagate(&stripe.InvalidRequestError{}, "")
  550. }
  551. return *checkoutSession.Subscription, nil
  552. }
  553. func (c *StripeController) getPriceIDFromSession(sessionID string) (string, error) {
  554. stripeClient := c.StripeClients[ente.DefaultStripeAccountCountry]
  555. params := &stripe.CheckoutSessionListLineItemsParams{}
  556. params.AddExpand("data.price")
  557. items := stripeClient.CheckoutSessions.ListLineItems(sessionID, params)
  558. for items.Next() { // Return the first PriceID that has been fetched
  559. return items.LineItem().Price.ID, nil
  560. }
  561. return "", stacktrace.Propagate(ente.ErrNotFound, "")
  562. }
  563. func (c *StripeController) getUserStripeSubscription(userID int64) (stripe.Subscription, error) {
  564. subscription, err := c.BillingRepo.GetUserSubscription(userID)
  565. if err != nil {
  566. return stripe.Subscription{}, stacktrace.Propagate(err, "")
  567. }
  568. if subscription.PaymentProvider != ente.Stripe {
  569. return stripe.Subscription{}, stacktrace.Propagate(ente.ErrCannotSwitchPaymentProvider, "")
  570. }
  571. client := c.StripeClients[subscription.Attributes.StripeAccountCountry]
  572. stripeSubscription, err := client.Subscriptions.Get(subscription.OriginalTransactionID, nil)
  573. if err != nil {
  574. return stripe.Subscription{}, stacktrace.Propagate(err, "")
  575. }
  576. return *stripeSubscription, nil
  577. }
  578. func (c *StripeController) getPlanAndAccount(stripeID string) (ente.BillingPlan, ente.StripeAccountCountry, error) {
  579. for stripeAccountCountry, billingPlansCountryWise := range c.BillingPlansPerAccount {
  580. for _, plans := range billingPlansCountryWise {
  581. for _, plan := range plans {
  582. if plan.StripeID == stripeID {
  583. return plan, stripeAccountCountry, nil
  584. }
  585. }
  586. }
  587. }
  588. return ente.BillingPlan{}, "", stacktrace.Propagate(ente.ErrNotFound, "")
  589. }
  590. func (c *StripeController) getEnteSubscriptionFromStripeSubscription(userID int64, stripeSubscription stripe.Subscription) (ente.Subscription, error) {
  591. productID := stripeSubscription.Items.Data[0].Price.ID
  592. plan, stripeAccountCountry, err := c.getPlanAndAccount(productID)
  593. if err != nil {
  594. return ente.Subscription{}, stacktrace.Propagate(err, "")
  595. }
  596. s := ente.Subscription{
  597. UserID: userID,
  598. PaymentProvider: ente.Stripe,
  599. ProductID: productID,
  600. Storage: plan.Storage,
  601. Attributes: ente.SubscriptionAttributes{CustomerID: stripeSubscription.Customer.ID, IsCancelled: false, StripeAccountCountry: stripeAccountCountry},
  602. OriginalTransactionID: stripeSubscription.ID,
  603. ExpiryTime: stripeSubscription.CurrentPeriodEnd * 1000 * 1000,
  604. }
  605. return s, nil
  606. }
  607. func (c *StripeController) UpdateBillingEmail(subscription ente.Subscription, newEmail string) error {
  608. params := &stripe.CustomerParams{Email: &newEmail}
  609. client := c.StripeClients[subscription.Attributes.StripeAccountCountry]
  610. _, err := client.Customers.Update(
  611. subscription.Attributes.CustomerID,
  612. params,
  613. )
  614. if err != nil {
  615. return stacktrace.Propagate(err, "failed to update stripe customer emailID")
  616. }
  617. return nil
  618. }
  619. func (c *StripeController) CancelSubAndDeleteCustomer(subscription ente.Subscription, logger *log.Entry) error {
  620. client := c.StripeClients[subscription.Attributes.StripeAccountCountry]
  621. if !subscription.Attributes.IsCancelled {
  622. prorateRefund := true
  623. logger.Info("cancelling sub with prorated refund")
  624. _, err := client.Subscriptions.Update(subscription.OriginalTransactionID, nil)
  625. if err != nil {
  626. stripeError := err.(*stripe.Error)
  627. errorMsg := fmt.Sprintf("subscription updation failed during account deletion: %s, %s", stripeError.Msg, stripeError.Code)
  628. log.Error(errorMsg)
  629. c.DiscordController.Notify(errorMsg)
  630. if stripeError.HTTPStatusCode == http.StatusNotFound {
  631. log.Error("Ignoring error since an active subscription could not be found")
  632. return nil
  633. } else if stripeError.HTTPStatusCode == http.StatusBadRequest {
  634. log.Error("Bad request while trying to delete account")
  635. return nil
  636. }
  637. return stacktrace.Propagate(err, "")
  638. }
  639. _, err = client.Subscriptions.Cancel(subscription.OriginalTransactionID, &stripe.SubscriptionCancelParams{
  640. Prorate: &prorateRefund,
  641. })
  642. if err != nil {
  643. stripeError := err.(*stripe.Error)
  644. logger.Error(fmt.Sprintf("subscription cancel failed msg= %s for userID=%d"+stripeError.Msg, subscription.UserID))
  645. // ignore if subscription doesn't exist, already deleted
  646. if stripeError.HTTPStatusCode != 404 {
  647. return stacktrace.Propagate(err, "")
  648. }
  649. }
  650. err = c.BillingRepo.UpdateSubscriptionCancellationStatus(subscription.UserID, true)
  651. if err != nil {
  652. return stacktrace.Propagate(err, "")
  653. }
  654. }
  655. logger.Info("deleting customer from stripe")
  656. _, err := client.Customers.Del(
  657. subscription.Attributes.CustomerID,
  658. &stripe.CustomerParams{},
  659. )
  660. if err != nil {
  661. stripeError := err.(*stripe.Error)
  662. switch stripeError.Type {
  663. case stripe.ErrorTypeInvalidRequest:
  664. if stripe.ErrorCodeResourceMissing == stripeError.Code {
  665. return nil
  666. }
  667. return stacktrace.Propagate(err, fmt.Sprintf("failed to delete customer %s", subscription.Attributes.CustomerID))
  668. default:
  669. return stacktrace.Propagate(err, fmt.Sprintf("failed to delete customer %s", subscription.Attributes.CustomerID))
  670. }
  671. }
  672. return nil
  673. }