// sso_cached_token.go (5.7 KB)
  1. package ssocreds
  2. import (
  3. "crypto/sha1"
  4. "encoding/hex"
  5. "encoding/json"
  6. "fmt"
  7. "io/ioutil"
  8. "os"
  9. "path/filepath"
  10. "strconv"
  11. "strings"
  12. "time"
  13. "github.com/aws/aws-sdk-go-v2/internal/sdk"
  14. "github.com/aws/aws-sdk-go-v2/internal/shareddefaults"
  15. )
  // osUserHomeDur is the function used to look up the user's home directory
// when building the SSO token cache path. It is held in a package-level
// variable rather than called directly, which allows it to be reassigned —
// presumably so tests can substitute a fixed directory (confirm).
//
// NOTE(review): "Dur" looks like a typo for "Dir"; the name is kept because
// other code in the package may reference it.
var osUserHomeDur = shareddefaults.UserHomeDir
  17. // StandardCachedTokenFilepath returns the filepath for the cached SSO token file, or
  18. // error if unable get derive the path. Key that will be used to compute a SHA1
  19. // value that is hex encoded.
  20. //
  21. // Derives the filepath using the Key as:
  22. //
  23. // ~/.aws/sso/cache/<sha1-hex-encoded-key>.json
  24. func StandardCachedTokenFilepath(key string) (string, error) {
  25. homeDir := osUserHomeDur()
  26. if len(homeDir) == 0 {
  27. return "", fmt.Errorf("unable to get USER's home directory for cached token")
  28. }
  29. hash := sha1.New()
  30. if _, err := hash.Write([]byte(key)); err != nil {
  31. return "", fmt.Errorf("unable to compute cached token filepath key SHA1 hash, %w", err)
  32. }
  33. cacheFilename := strings.ToLower(hex.EncodeToString(hash.Sum(nil))) + ".json"
  34. return filepath.Join(homeDir, ".aws", "sso", "cache", cacheFilename), nil
  35. }
  36. type tokenKnownFields struct {
  37. AccessToken string `json:"accessToken,omitempty"`
  38. ExpiresAt *rfc3339 `json:"expiresAt,omitempty"`
  39. RefreshToken string `json:"refreshToken,omitempty"`
  40. ClientID string `json:"clientId,omitempty"`
  41. ClientSecret string `json:"clientSecret,omitempty"`
  42. }
  43. type token struct {
  44. tokenKnownFields
  45. UnknownFields map[string]interface{} `json:"-"`
  46. }
  47. func (t token) MarshalJSON() ([]byte, error) {
  48. fields := map[string]interface{}{}
  49. setTokenFieldString(fields, "accessToken", t.AccessToken)
  50. setTokenFieldRFC3339(fields, "expiresAt", t.ExpiresAt)
  51. setTokenFieldString(fields, "refreshToken", t.RefreshToken)
  52. setTokenFieldString(fields, "clientId", t.ClientID)
  53. setTokenFieldString(fields, "clientSecret", t.ClientSecret)
  54. for k, v := range t.UnknownFields {
  55. if _, ok := fields[k]; ok {
  56. return nil, fmt.Errorf("unknown token field %v, duplicates known field", k)
  57. }
  58. fields[k] = v
  59. }
  60. return json.Marshal(fields)
  61. }
  // setTokenFieldString records value under key in fields only when value is
// non-empty, so unset string fields are left out of the serialized token.
func setTokenFieldString(fields map[string]interface{}, key, value string) {
	if value != "" {
		fields[key] = value
	}
}
  68. func setTokenFieldRFC3339(fields map[string]interface{}, key string, value *rfc3339) {
  69. if value == nil {
  70. return
  71. }
  72. fields[key] = value
  73. }
  74. func (t *token) UnmarshalJSON(b []byte) error {
  75. var fields map[string]interface{}
  76. if err := json.Unmarshal(b, &fields); err != nil {
  77. return nil
  78. }
  79. t.UnknownFields = map[string]interface{}{}
  80. for k, v := range fields {
  81. var err error
  82. switch k {
  83. case "accessToken":
  84. err = getTokenFieldString(v, &t.AccessToken)
  85. case "expiresAt":
  86. err = getTokenFieldRFC3339(v, &t.ExpiresAt)
  87. case "refreshToken":
  88. err = getTokenFieldString(v, &t.RefreshToken)
  89. case "clientId":
  90. err = getTokenFieldString(v, &t.ClientID)
  91. case "clientSecret":
  92. err = getTokenFieldString(v, &t.ClientSecret)
  93. default:
  94. t.UnknownFields[k] = v
  95. }
  96. if err != nil {
  97. return fmt.Errorf("field %q, %w", k, err)
  98. }
  99. }
  100. return nil
  101. }
  // getTokenFieldString copies v into *value when v holds a string. For any
// other dynamic type, *value is reset to "" and an error naming that type is
// returned.
func getTokenFieldString(v interface{}, value *string) error {
	s, isString := v.(string)
	*value = s
	if isString {
		return nil
	}
	return fmt.Errorf("expect value to be string, got %T", v)
}
  110. func getTokenFieldRFC3339(v interface{}, value **rfc3339) error {
  111. var stringValue string
  112. if err := getTokenFieldString(v, &stringValue); err != nil {
  113. return err
  114. }
  115. timeValue, err := parseRFC3339(stringValue)
  116. if err != nil {
  117. return err
  118. }
  119. *value = &timeValue
  120. return nil
  121. }
  122. func loadCachedToken(filename string) (token, error) {
  123. fileBytes, err := ioutil.ReadFile(filename)
  124. if err != nil {
  125. return token{}, fmt.Errorf("failed to read cached SSO token file, %w", err)
  126. }
  127. var t token
  128. if err := json.Unmarshal(fileBytes, &t); err != nil {
  129. return token{}, fmt.Errorf("failed to parse cached SSO token file, %w", err)
  130. }
  131. if len(t.AccessToken) == 0 || t.ExpiresAt == nil || time.Time(*t.ExpiresAt).IsZero() {
  132. return token{}, fmt.Errorf(
  133. "cached SSO token must contain accessToken and expiresAt fields")
  134. }
  135. return t, nil
  136. }
  137. func storeCachedToken(filename string, t token, fileMode os.FileMode) (err error) {
  138. tmpFilename := filename + ".tmp-" + strconv.FormatInt(sdk.NowTime().UnixNano(), 10)
  139. if err := writeCacheFile(tmpFilename, fileMode, t); err != nil {
  140. return err
  141. }
  142. if err := os.Rename(tmpFilename, filename); err != nil {
  143. return fmt.Errorf("failed to replace old cached SSO token file, %w", err)
  144. }
  145. return nil
  146. }
  147. func writeCacheFile(filename string, fileMode os.FileMode, t token) (err error) {
  148. var f *os.File
  149. f, err = os.OpenFile(filename, os.O_CREATE|os.O_TRUNC|os.O_RDWR, fileMode)
  150. if err != nil {
  151. return fmt.Errorf("failed to create cached SSO token file %w", err)
  152. }
  153. defer func() {
  154. closeErr := f.Close()
  155. if err == nil && closeErr != nil {
  156. err = fmt.Errorf("failed to close cached SSO token file, %w", closeErr)
  157. }
  158. }()
  159. encoder := json.NewEncoder(f)
  160. if err = encoder.Encode(t); err != nil {
  161. return fmt.Errorf("failed to serialize cached SSO token, %w", err)
  162. }
  163. return nil
  164. }
  // rfc3339 is a time.Time that is represented in JSON as a timestamp string
// in the time.RFC3339 layout.
type rfc3339 time.Time

// parseRFC3339 parses v with the time.RFC3339 layout and returns it as an
// rfc3339 value. On failure the returned error wraps the time.Parse error.
func parseRFC3339(v string) (rfc3339, error) {
	parsed, err := time.Parse(time.RFC3339, v)
	if err != nil {
		return rfc3339{}, fmt.Errorf("expected RFC3339 timestamp: %w", err)
	}
	return rfc3339(parsed), nil
}

// UnmarshalJSON decodes a JSON string holding an RFC 3339 timestamp into r.
// Following the encoding/json convention, a JSON null is a no-op. If the
// input is not a JSON string or not a valid timestamp, the error is returned
// and r is left unmodified.
func (r *rfc3339) UnmarshalJSON(data []byte) error {
	if string(data) == "null" {
		return nil
	}

	var value string
	// Use JSON unmarshal to unescape the quoted value making use of JSON's
	// unquoting rules.
	if err := json.Unmarshal(data, &value); err != nil {
		return err
	}

	parsed, err := parseRFC3339(value)
	if err != nil {
		return err
	}
	*r = parsed
	return nil
}

// MarshalJSON encodes r as a JSON string formatted with time.RFC3339.
//
// It has a value receiver (as time.Time does), so both rfc3339 values and
// non-nil *rfc3339 pointers marshal as a timestamp. encoding/json renders a
// nil *rfc3339 as null.
func (r rfc3339) MarshalJSON() ([]byte, error) {
	value := time.Time(r).Format(time.RFC3339)
	// Use JSON marshal to quote the value, applying JSON's escaping rules.
	return json.Marshal(value)
}