appstore.go 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. package controller
  2. import (
  3. "context"
  4. "fmt"
  5. "github.com/ente-io/museum/pkg/controller/commonbilling"
  6. "github.com/prometheus/common/log"
  7. "strconv"
  8. "strings"
  9. "github.com/ente-io/stacktrace"
  10. "github.com/gin-contrib/requestid"
  11. "github.com/gin-gonic/gin"
  12. "github.com/sirupsen/logrus"
  13. "github.com/spf13/viper"
  14. "github.com/awa/go-iap/appstore"
  15. "github.com/ente-io/museum/ente"
  16. "github.com/ente-io/museum/pkg/repo"
  17. "github.com/ente-io/museum/pkg/utils/array"
  18. )
  19. // AppStoreController provides abstractions for handling billing on AppStore
  20. type AppStoreController struct {
  21. AppStoreClient appstore.Client
  22. BillingRepo *repo.BillingRepository
  23. FileRepo *repo.FileRepository
  24. UserRepo *repo.UserRepository
  25. BillingPlansPerCountry ente.BillingPlansPerCountry
  26. CommonBillCtrl *commonbilling.Controller
  27. // appStoreSharedPassword is the password to be used to access AppStore APIs
  28. appStoreSharedPassword string
  29. }
  30. // Return a new instance of AppStoreController
  31. func NewAppStoreController(
  32. plans ente.BillingPlansPerCountry,
  33. billingRepo *repo.BillingRepository,
  34. fileRepo *repo.FileRepository,
  35. userRepo *repo.UserRepository,
  36. commonBillCtrl *commonbilling.Controller,
  37. ) *AppStoreController {
  38. appleSharedSecret := viper.GetString("apple.shared-secret")
  39. return &AppStoreController{
  40. AppStoreClient: *appstore.New(),
  41. BillingRepo: billingRepo,
  42. FileRepo: fileRepo,
  43. UserRepo: userRepo,
  44. BillingPlansPerCountry: plans,
  45. appStoreSharedPassword: appleSharedSecret,
  46. CommonBillCtrl: commonBillCtrl,
  47. }
  48. }
  49. var SubsUpdateNotificationTypes = []string{string(appstore.NotificationTypeDidChangeRenewalStatus), string(appstore.NotificationTypeCancel), string(appstore.NotificationTypeDidRevoke)}
  50. // HandleNotification handles an AppStore notification
  51. func (c *AppStoreController) HandleNotification(ctx *gin.Context, notification appstore.SubscriptionNotification) error {
  52. logger := logrus.WithFields(logrus.Fields{
  53. "req_id": requestid.Get(ctx),
  54. })
  55. purchase, err := c.verifyAppStoreSubscription(notification.UnifiedReceipt.LatestReceipt)
  56. if err != nil {
  57. return stacktrace.Propagate(err, "")
  58. }
  59. latestReceiptInfo := c.getLatestReceiptInfo(purchase.LatestReceiptInfo)
  60. if latestReceiptInfo.TransactionID == latestReceiptInfo.OriginalTransactionID && !array.StringInList(string(notification.NotificationType), SubsUpdateNotificationTypes) {
  61. var logMsg = fmt.Sprintf("Ignoring notification of type %s", notification.NotificationType)
  62. if notification.NotificationType != appstore.NotificationTypeInitialBuy {
  63. // log unexpected notification types
  64. logger.Error(logMsg)
  65. } else {
  66. logger.Info(logMsg)
  67. }
  68. // First subscription, no user to link to
  69. return nil
  70. }
  71. subscription, err := c.BillingRepo.GetSubscriptionForTransaction(latestReceiptInfo.OriginalTransactionID, ente.AppStore)
  72. if err != nil {
  73. return stacktrace.Propagate(err, "")
  74. }
  75. expiryTimeInMillis, _ := strconv.ParseInt(latestReceiptInfo.ExpiresDate.ExpiresDateMS, 10, 64)
  76. if latestReceiptInfo.ProductID == subscription.ProductID && expiryTimeInMillis*1000 < subscription.ExpiryTime {
  77. // Outdated notification, no-op
  78. } else {
  79. if latestReceiptInfo.ProductID != subscription.ProductID {
  80. var newPlan ente.BillingPlan
  81. plans := c.BillingPlansPerCountry["EU"] // Country code is irrelevant since Storage will be the same for a given subscriptionID
  82. for _, plan := range plans {
  83. if plan.IOSID == latestReceiptInfo.ProductID {
  84. newPlan = plan
  85. break
  86. }
  87. }
  88. if newPlan.Storage < subscription.Storage { // Downgrade
  89. canDowngrade, canDowngradeErr := c.CommonBillCtrl.CanDowngradeToGivenStorage(newPlan.Storage, subscription.UserID)
  90. if canDowngradeErr != nil {
  91. return stacktrace.Propagate(canDowngradeErr, "")
  92. }
  93. if !canDowngrade {
  94. return stacktrace.Propagate(ente.ErrCannotDowngrade, "")
  95. }
  96. log.Info("Usage is good")
  97. }
  98. newSubscription := ente.Subscription{
  99. Storage: newPlan.Storage,
  100. ExpiryTime: expiryTimeInMillis * 1000,
  101. ProductID: latestReceiptInfo.ProductID,
  102. PaymentProvider: ente.AppStore,
  103. OriginalTransactionID: latestReceiptInfo.OriginalTransactionID,
  104. Attributes: ente.SubscriptionAttributes{LatestVerificationData: notification.UnifiedReceipt.LatestReceipt},
  105. }
  106. err = c.BillingRepo.ReplaceSubscription(
  107. subscription.ID,
  108. newSubscription,
  109. )
  110. if err != nil {
  111. return stacktrace.Propagate(err, "")
  112. }
  113. } else {
  114. if notification.NotificationType == appstore.NotificationTypeDidChangeRenewalStatus {
  115. err := c.BillingRepo.UpdateSubscriptionCancellationStatus(subscription.UserID, notification.AutoRenewStatus == "false")
  116. if err != nil {
  117. return stacktrace.Propagate(err, "")
  118. }
  119. } else if notification.NotificationType == appstore.NotificationTypeCancel || notification.NotificationType == appstore.NotificationTypeDidRevoke {
  120. err := c.BillingRepo.UpdateSubscriptionCancellationStatus(subscription.UserID, true)
  121. if err != nil {
  122. return stacktrace.Propagate(err, "")
  123. }
  124. }
  125. err = c.BillingRepo.UpdateSubscriptionExpiryTime(subscription.ID, expiryTimeInMillis*1000)
  126. if err != nil {
  127. return stacktrace.Propagate(err, "")
  128. }
  129. }
  130. }
  131. err = c.BillingRepo.LogAppStorePush(subscription.UserID, notification, *purchase)
  132. return stacktrace.Propagate(err, "")
  133. }
  134. // GetVerifiedSubscription verifies and returns the verified subscription
  135. func (c *AppStoreController) GetVerifiedSubscription(userID int64, productID string, verificationData string) (ente.Subscription, error) {
  136. var s ente.Subscription
  137. s.UserID = userID
  138. s.ProductID = productID
  139. s.PaymentProvider = ente.AppStore
  140. s.Attributes.LatestVerificationData = verificationData
  141. plans := c.BillingPlansPerCountry["EU"] // Country code is irrelevant since Storage will be the same for a given subscriptionID
  142. response, err := c.verifyAppStoreSubscription(verificationData)
  143. if err != nil {
  144. return ente.Subscription{}, stacktrace.Propagate(err, "")
  145. }
  146. for _, plan := range plans {
  147. if plan.IOSID == productID {
  148. s.Storage = plan.Storage
  149. break
  150. }
  151. }
  152. latestReceiptInfo := c.getLatestReceiptInfo(response.LatestReceiptInfo)
  153. s.OriginalTransactionID = latestReceiptInfo.OriginalTransactionID
  154. expiryTime, _ := strconv.ParseInt(latestReceiptInfo.ExpiresDate.ExpiresDateMS, 10, 64)
  155. s.ExpiryTime = expiryTime * 1000
  156. return s, nil
  157. }
  158. // VerifyAppStoreSubscription verifies an AppStore subscription
  159. func (c *AppStoreController) verifyAppStoreSubscription(verificationData string) (*appstore.IAPResponse, error) {
  160. iapRequest := appstore.IAPRequest{
  161. ReceiptData: verificationData,
  162. Password: c.appStoreSharedPassword,
  163. }
  164. response := &appstore.IAPResponse{}
  165. context := context.Background()
  166. err := c.AppStoreClient.Verify(context, iapRequest, response)
  167. if err != nil {
  168. return nil, stacktrace.Propagate(err, "")
  169. }
  170. if response.Status != 0 {
  171. return nil, ente.ErrBadRequest
  172. }
  173. return response, nil
  174. }
  175. func (c *AppStoreController) getLatestReceiptInfo(receiptInfo []appstore.InApp) appstore.InApp {
  176. latestReceiptInfo := receiptInfo[0]
  177. for _, receiptInfo := range receiptInfo {
  178. if strings.Compare(latestReceiptInfo.ExpiresDate.ExpiresDateMS, receiptInfo.ExpiresDate.ExpiresDateMS) < 0 {
  179. latestReceiptInfo = receiptInfo
  180. }
  181. }
  182. return latestReceiptInfo
  183. }