token_provider.go 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. package imds
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "net/http"
  7. "sync"
  8. "sync/atomic"
  9. "time"
  10. smithy "github.com/aws/smithy-go"
  11. "github.com/aws/smithy-go/middleware"
  12. smithyhttp "github.com/aws/smithy-go/transport/http"
  13. )
  // tokenHeader is the name of the HTTP header used to carry the API token on
  // requests.
  const tokenHeader = "x-aws-ec2-metadata-token"

  // defaultTokenTTL is the default lifetime requested for an API token.
  // NOTE(review): not referenced in this part of the file; presumably passed
  // to newTokenProvider by the client setup code. Confirm against the caller.
  const defaultTokenTTL = 5 * time.Minute
  19. type tokenProvider struct {
  20. client *Client
  21. tokenTTL time.Duration
  22. token *apiToken
  23. tokenMux sync.RWMutex
  24. disabled uint32 // Atomic updated
  25. }
  26. func newTokenProvider(client *Client, ttl time.Duration) *tokenProvider {
  27. return &tokenProvider{
  28. client: client,
  29. tokenTTL: ttl,
  30. }
  31. }
  // apiToken holds the API token used by all operation calls for the EC2
  // Instance Metadata Service, together with the time at which it expires.
  type apiToken struct {
  	// token is the opaque token value placed in the request header.
  	token string

  	// expires is the instant after which the token must not be used.
  	expires time.Time
  }
  // timeNow is the clock used when computing and checking token expiry. It is
  // a package variable rather than a direct time.Now call, presumably so tests
  // can substitute a fixed clock.
  var timeNow func() time.Time = time.Now
  39. // Expired returns if the token is expired.
  40. func (t *apiToken) Expired() bool {
  41. // Calling Round(0) on the current time will truncate the monotonic reading only. Ensures credential expiry
  42. // time is always based on reported wall-clock time.
  43. return timeNow().Round(0).After(t.expires)
  44. }
  45. func (t *tokenProvider) ID() string { return "APITokenProvider" }
  46. // HandleFinalize is the finalize stack middleware, that if the token provider is
  47. // enabled, will attempt to add the cached API token to the request. If the API
  48. // token is not cached, it will be retrieved in a separate API call, getToken.
  49. //
  50. // For retry attempts, handler must be added after attempt retryer.
  51. //
  52. // If request for getToken fails the token provider may be disabled from future
  53. // requests, depending on the response status code.
  54. func (t *tokenProvider) HandleFinalize(
  55. ctx context.Context, input middleware.FinalizeInput, next middleware.FinalizeHandler,
  56. ) (
  57. out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
  58. ) {
  59. if !t.enabled() {
  60. // short-circuits to insecure data flow if token provider is disabled.
  61. return next.HandleFinalize(ctx, input)
  62. }
  63. req, ok := input.Request.(*smithyhttp.Request)
  64. if !ok {
  65. return out, metadata, fmt.Errorf("unexpected transport request type %T", input.Request)
  66. }
  67. tok, err := t.getToken(ctx)
  68. if err != nil {
  69. // If the error allows the token to downgrade to insecure flow allow that.
  70. var bypassErr *bypassTokenRetrievalError
  71. if errors.As(err, &bypassErr) {
  72. return next.HandleFinalize(ctx, input)
  73. }
  74. return out, metadata, fmt.Errorf("failed to get API token, %w", err)
  75. }
  76. req.Header.Set(tokenHeader, tok.token)
  77. return next.HandleFinalize(ctx, input)
  78. }
  79. // HandleDeserialize is the deserialize stack middleware for determining if the
  80. // operation the token provider is decorating failed because of a 401
  81. // unauthorized status code. If the operation failed for that reason the token
  82. // provider needs to be re-enabled so that it can start adding the API token to
  83. // operation calls.
  84. func (t *tokenProvider) HandleDeserialize(
  85. ctx context.Context, input middleware.DeserializeInput, next middleware.DeserializeHandler,
  86. ) (
  87. out middleware.DeserializeOutput, metadata middleware.Metadata, err error,
  88. ) {
  89. out, metadata, err = next.HandleDeserialize(ctx, input)
  90. if err == nil {
  91. return out, metadata, err
  92. }
  93. resp, ok := out.RawResponse.(*smithyhttp.Response)
  94. if !ok {
  95. return out, metadata, fmt.Errorf("expect HTTP transport, got %T", out.RawResponse)
  96. }
  97. if resp.StatusCode == http.StatusUnauthorized { // unauthorized
  98. err = &retryableError{Err: err}
  99. t.enable()
  100. }
  101. return out, metadata, err
  102. }
  // retryableError wraps an operation error and marks it as retryable.
  type retryableError struct {
  	// Err is the wrapped error; it must be non-nil.
  	Err error
  }

  // RetryableError always reports true.
  func (*retryableError) RetryableError() bool { return true }

  // Error returns the wrapped error's message unchanged.
  func (e *retryableError) Error() string { return e.Err.Error() }

  // Unwrap returns the wrapped error so that errors.Is and errors.As can
  // inspect the underlying cause.
  func (e *retryableError) Unwrap() error { return e.Err }
  108. func (t *tokenProvider) getToken(ctx context.Context) (tok *apiToken, err error) {
  109. if !t.enabled() {
  110. return nil, &bypassTokenRetrievalError{
  111. Err: fmt.Errorf("cannot get API token, provider disabled"),
  112. }
  113. }
  114. t.tokenMux.RLock()
  115. tok = t.token
  116. t.tokenMux.RUnlock()
  117. if tok != nil && !tok.Expired() {
  118. return tok, nil
  119. }
  120. tok, err = t.updateToken(ctx)
  121. if err != nil {
  122. return nil, fmt.Errorf("cannot get API token, %w", err)
  123. }
  124. return tok, nil
  125. }
  126. func (t *tokenProvider) updateToken(ctx context.Context) (*apiToken, error) {
  127. t.tokenMux.Lock()
  128. defer t.tokenMux.Unlock()
  129. // Prevent multiple requests to update retrieving the token.
  130. if t.token != nil && !t.token.Expired() {
  131. tok := t.token
  132. return tok, nil
  133. }
  134. result, err := t.client.getToken(ctx, &getTokenInput{
  135. TokenTTL: t.tokenTTL,
  136. })
  137. if err != nil {
  138. // change the disabled flag on token provider to true, when error is request timeout error.
  139. var statusErr interface{ HTTPStatusCode() int }
  140. if errors.As(err, &statusErr) {
  141. switch statusErr.HTTPStatusCode() {
  142. // Disable get token if failed because of 403, 404, or 405
  143. case http.StatusForbidden,
  144. http.StatusNotFound,
  145. http.StatusMethodNotAllowed:
  146. t.disable()
  147. // 400 errors are terminal, and need to be upstreamed
  148. case http.StatusBadRequest:
  149. return nil, err
  150. }
  151. }
  152. // Disable if request send failed or timed out getting response
  153. var re *smithyhttp.RequestSendError
  154. var ce *smithy.CanceledError
  155. if errors.As(err, &re) || errors.As(err, &ce) {
  156. atomic.StoreUint32(&t.disabled, 1)
  157. }
  158. // Token couldn't be retrieved, but bypass this, and allow the
  159. // request to continue.
  160. return nil, &bypassTokenRetrievalError{Err: err}
  161. }
  162. tok := &apiToken{
  163. token: result.Token,
  164. expires: timeNow().Add(result.TokenTTL),
  165. }
  166. t.token = tok
  167. return tok, nil
  168. }
  // bypassTokenRetrievalError marks a token retrieval failure that does not
  // have to fail the request: HandleFinalize forwards the request without a
  // token when it finds this error type in the chain.
  type bypassTokenRetrievalError struct {
  	// Err is the underlying reason the token could not be retrieved.
  	Err error
  }

  // Error formats the underlying error behind a fixed prefix.
  func (e *bypassTokenRetrievalError) Error() string {
  	msg := fmt.Sprintf("bypass token retrieval, %v", e.Err)
  	return msg
  }

  // Unwrap exposes the underlying error to errors.Is and errors.As.
  func (e *bypassTokenRetrievalError) Unwrap() error {
  	return e.Err
  }
  176. // enabled returns if the token provider is current enabled or not.
  177. func (t *tokenProvider) enabled() bool {
  178. return atomic.LoadUint32(&t.disabled) == 0
  179. }
  180. // disable disables the token provider and it will no longer attempt to inject
  181. // the token, nor request updates.
  182. func (t *tokenProvider) disable() {
  183. atomic.StoreUint32(&t.disabled, 1)
  184. }
  185. // enable enables the token provide to start refreshing tokens, and adding them
  186. // to the pending request.
  187. func (t *tokenProvider) enable() {
  188. t.tokenMux.Lock()
  189. t.token = nil
  190. t.tokenMux.Unlock()
  191. atomic.StoreUint32(&t.disabled, 0)
  192. }