123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237 |
- package imds
- import (
- "context"
- "errors"
- "fmt"
- "net/http"
- "sync"
- "sync/atomic"
- "time"
- smithy "github.com/aws/smithy-go"
- "github.com/aws/smithy-go/middleware"
- smithyhttp "github.com/aws/smithy-go/transport/http"
- )
const (
	// tokenHeader is the HTTP request header that carries the IMDS API token;
	// HandleFinalize sets it on outgoing requests.
	tokenHeader = "x-aws-ec2-metadata-token"

	// defaultTokenTTL is the default API token time-to-live (5 minutes).
	// NOTE(review): not referenced in this chunk; presumably passed as the ttl
	// to newTokenProvider elsewhere — confirm.
	defaultTokenTTL = 5 * time.Minute
)
- type tokenProvider struct {
- client *Client
- tokenTTL time.Duration
- token *apiToken
- tokenMux sync.RWMutex
- disabled uint32 // Atomic updated
- }
- func newTokenProvider(client *Client, ttl time.Duration) *tokenProvider {
- return &tokenProvider{
- client: client,
- tokenTTL: ttl,
- }
- }
// apiToken is the API token used by all operation calls to the EC2 Instance
// Metadata Service, together with the time at which it expires.
type apiToken struct {
	token   string
	expires time.Time
}

// timeNow is the clock used for token expiry calculations. It is a package
// variable, presumably so tests can substitute a fixed clock.
var timeNow = time.Now

// Expired reports whether the token's expiry time has passed. A token whose
// expiry is exactly the current time is not yet considered expired.
func (t *apiToken) Expired() bool {
	// Round(0) strips only the monotonic clock reading from the current time,
	// so token expiry is always judged against wall-clock time.
	return timeNow().Round(0).After(t.expires)
}
- func (t *tokenProvider) ID() string { return "APITokenProvider" }
- // HandleFinalize is the finalize stack middleware, that if the token provider is
- // enabled, will attempt to add the cached API token to the request. If the API
- // token is not cached, it will be retrieved in a separate API call, getToken.
- //
- // For retry attempts, handler must be added after attempt retryer.
- //
- // If request for getToken fails the token provider may be disabled from future
- // requests, depending on the response status code.
- func (t *tokenProvider) HandleFinalize(
- ctx context.Context, input middleware.FinalizeInput, next middleware.FinalizeHandler,
- ) (
- out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
- ) {
- if !t.enabled() {
- // short-circuits to insecure data flow if token provider is disabled.
- return next.HandleFinalize(ctx, input)
- }
- req, ok := input.Request.(*smithyhttp.Request)
- if !ok {
- return out, metadata, fmt.Errorf("unexpected transport request type %T", input.Request)
- }
- tok, err := t.getToken(ctx)
- if err != nil {
- // If the error allows the token to downgrade to insecure flow allow that.
- var bypassErr *bypassTokenRetrievalError
- if errors.As(err, &bypassErr) {
- return next.HandleFinalize(ctx, input)
- }
- return out, metadata, fmt.Errorf("failed to get API token, %w", err)
- }
- req.Header.Set(tokenHeader, tok.token)
- return next.HandleFinalize(ctx, input)
- }
- // HandleDeserialize is the deserialize stack middleware for determining if the
- // operation the token provider is decorating failed because of a 401
- // unauthorized status code. If the operation failed for that reason the token
- // provider needs to be re-enabled so that it can start adding the API token to
- // operation calls.
- func (t *tokenProvider) HandleDeserialize(
- ctx context.Context, input middleware.DeserializeInput, next middleware.DeserializeHandler,
- ) (
- out middleware.DeserializeOutput, metadata middleware.Metadata, err error,
- ) {
- out, metadata, err = next.HandleDeserialize(ctx, input)
- if err == nil {
- return out, metadata, err
- }
- resp, ok := out.RawResponse.(*smithyhttp.Response)
- if !ok {
- return out, metadata, fmt.Errorf("expect HTTP transport, got %T", out.RawResponse)
- }
- if resp.StatusCode == http.StatusUnauthorized { // unauthorized
- err = &retryableError{Err: err}
- t.enable()
- }
- return out, metadata, err
- }
// retryableError wraps an operation error and marks it as retryable:
// RetryableError always reports true (presumably consulted by the SDK
// retryer). Error returns the wrapped error's message unchanged.
type retryableError struct {
	Err error
}

// RetryableError reports that the wrapped error should be retried.
func (*retryableError) RetryableError() bool { return true }

// Error returns the message of the wrapped error.
func (e *retryableError) Error() string { return e.Err.Error() }

// Unwrap exposes the wrapped error so errors.Is / errors.As can reach the
// underlying failure (e.g. the 401 response error) through this wrapper.
func (e *retryableError) Unwrap() error { return e.Err }
- func (t *tokenProvider) getToken(ctx context.Context) (tok *apiToken, err error) {
- if !t.enabled() {
- return nil, &bypassTokenRetrievalError{
- Err: fmt.Errorf("cannot get API token, provider disabled"),
- }
- }
- t.tokenMux.RLock()
- tok = t.token
- t.tokenMux.RUnlock()
- if tok != nil && !tok.Expired() {
- return tok, nil
- }
- tok, err = t.updateToken(ctx)
- if err != nil {
- return nil, fmt.Errorf("cannot get API token, %w", err)
- }
- return tok, nil
- }
- func (t *tokenProvider) updateToken(ctx context.Context) (*apiToken, error) {
- t.tokenMux.Lock()
- defer t.tokenMux.Unlock()
- // Prevent multiple requests to update retrieving the token.
- if t.token != nil && !t.token.Expired() {
- tok := t.token
- return tok, nil
- }
- result, err := t.client.getToken(ctx, &getTokenInput{
- TokenTTL: t.tokenTTL,
- })
- if err != nil {
- // change the disabled flag on token provider to true, when error is request timeout error.
- var statusErr interface{ HTTPStatusCode() int }
- if errors.As(err, &statusErr) {
- switch statusErr.HTTPStatusCode() {
- // Disable get token if failed because of 403, 404, or 405
- case http.StatusForbidden,
- http.StatusNotFound,
- http.StatusMethodNotAllowed:
- t.disable()
- // 400 errors are terminal, and need to be upstreamed
- case http.StatusBadRequest:
- return nil, err
- }
- }
- // Disable if request send failed or timed out getting response
- var re *smithyhttp.RequestSendError
- var ce *smithy.CanceledError
- if errors.As(err, &re) || errors.As(err, &ce) {
- atomic.StoreUint32(&t.disabled, 1)
- }
- // Token couldn't be retrieved, but bypass this, and allow the
- // request to continue.
- return nil, &bypassTokenRetrievalError{Err: err}
- }
- tok := &apiToken{
- token: result.Token,
- expires: timeNow().Add(result.TokenTTL),
- }
- t.token = tok
- return tok, nil
- }
// bypassTokenRetrievalError reports that the API token could not be obtained
// but the operation may continue without one. HandleFinalize looks for this
// type with errors.As and, when found, sends the request without the token
// header.
type bypassTokenRetrievalError struct {
	Err error
}

// Unwrap returns the underlying retrieval failure.
func (b *bypassTokenRetrievalError) Unwrap() error {
	return b.Err
}

// Error formats the underlying failure behind a "bypass token retrieval" prefix.
func (b *bypassTokenRetrievalError) Error() string {
	const format = "bypass token retrieval, %v"
	return fmt.Sprintf(format, b.Err)
}
- // enabled returns if the token provider is current enabled or not.
- func (t *tokenProvider) enabled() bool {
- return atomic.LoadUint32(&t.disabled) == 0
- }
- // disable disables the token provider and it will no longer attempt to inject
- // the token, nor request updates.
- func (t *tokenProvider) disable() {
- atomic.StoreUint32(&t.disabled, 1)
- }
- // enable enables the token provide to start refreshing tokens, and adding them
- // to the pending request.
- func (t *tokenProvider) enable() {
- t.tokenMux.Lock()
- t.token = nil
- t.tokenMux.Unlock()
- atomic.StoreUint32(&t.disabled, 0)
- }
|