123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229 |
- package ec2rolecreds
- import (
- "bufio"
- "context"
- "encoding/json"
- "fmt"
- "math"
- "path"
- "strings"
- "time"
- "github.com/aws/aws-sdk-go-v2/aws"
- "github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
- sdkrand "github.com/aws/aws-sdk-go-v2/internal/rand"
- "github.com/aws/aws-sdk-go-v2/internal/sdk"
- "github.com/aws/smithy-go"
- "github.com/aws/smithy-go/logging"
- "github.com/aws/smithy-go/middleware"
- )
// ProviderName is the name of the EC2 role credentials provider. It is set as
// the Source of every aws.Credentials value returned by Provider.Retrieve.
const (
	ProviderName = "EC2RoleProvider"
)
- // GetMetadataAPIClient provides the interface for an EC2 IMDS API client for the
- // GetMetadata operation.
- type GetMetadataAPIClient interface {
- GetMetadata(context.Context, *imds.GetMetadataInput, ...func(*imds.Options)) (*imds.GetMetadataOutput, error)
- }
- // A Provider retrieves credentials from the EC2 service, and keeps track if
- // those credentials are expired.
- //
- // The New function must be used to create the with a custom EC2 IMDS client.
- //
- // p := &ec2rolecreds.New(func(o *ec2rolecreds.Options{
- // o.Client = imds.New(imds.Options{/* custom options */})
- // })
- type Provider struct {
- options Options
- }
- // Options is a list of user settable options for setting the behavior of the Provider.
- type Options struct {
- // The API client that will be used by the provider to make GetMetadata API
- // calls to EC2 IMDS.
- //
- // If nil, the provider will default to the EC2 IMDS client.
- Client GetMetadataAPIClient
- }
- // New returns an initialized Provider value configured to retrieve
- // credentials from EC2 Instance Metadata service.
- func New(optFns ...func(*Options)) *Provider {
- options := Options{}
- for _, fn := range optFns {
- fn(&options)
- }
- if options.Client == nil {
- options.Client = imds.New(imds.Options{})
- }
- return &Provider{
- options: options,
- }
- }
- // Retrieve retrieves credentials from the EC2 service. Error will be returned
- // if the request fails, or unable to extract the desired credentials.
- func (p *Provider) Retrieve(ctx context.Context) (aws.Credentials, error) {
- credsList, err := requestCredList(ctx, p.options.Client)
- if err != nil {
- return aws.Credentials{Source: ProviderName}, err
- }
- if len(credsList) == 0 {
- return aws.Credentials{Source: ProviderName},
- fmt.Errorf("unexpected empty EC2 IMDS role list")
- }
- credsName := credsList[0]
- roleCreds, err := requestCred(ctx, p.options.Client, credsName)
- if err != nil {
- return aws.Credentials{Source: ProviderName}, err
- }
- creds := aws.Credentials{
- AccessKeyID: roleCreds.AccessKeyID,
- SecretAccessKey: roleCreds.SecretAccessKey,
- SessionToken: roleCreds.Token,
- Source: ProviderName,
- CanExpire: true,
- Expires: roleCreds.Expiration,
- }
- // Cap role credentials Expires to 1 hour so they can be refreshed more
- // often. Jitter will be applied credentials cache if being used.
- if anHour := sdk.NowTime().Add(1 * time.Hour); creds.Expires.After(anHour) {
- creds.Expires = anHour
- }
- return creds, nil
- }
- // HandleFailToRefresh will extend the credentials Expires time if it it is
- // expired. If the credentials will not expire within the minimum time, they
- // will be returned.
- //
- // If the credentials cannot expire, the original error will be returned.
- func (p *Provider) HandleFailToRefresh(ctx context.Context, prevCreds aws.Credentials, err error) (
- aws.Credentials, error,
- ) {
- if !prevCreds.CanExpire {
- return aws.Credentials{}, err
- }
- if prevCreds.Expires.After(sdk.NowTime().Add(5 * time.Minute)) {
- return prevCreds, nil
- }
- newCreds := prevCreds
- randFloat64, err := sdkrand.CryptoRandFloat64()
- if err != nil {
- return aws.Credentials{}, fmt.Errorf("failed to get random float, %w", err)
- }
- // Random distribution of [5,15) minutes.
- expireOffset := time.Duration(randFloat64*float64(10*time.Minute)) + 5*time.Minute
- newCreds.Expires = sdk.NowTime().Add(expireOffset)
- logger := middleware.GetLogger(ctx)
- logger.Logf(logging.Warn, "Attempting credential expiration extension due to a credential service availability issue. A refresh of these credentials will be attempted again in %v minutes.", math.Floor(expireOffset.Minutes()))
- return newCreds, nil
- }
- // AdjustExpiresBy will adds the passed in duration to the passed in
- // credential's Expires time, unless the time until Expires is less than 15
- // minutes. Returns the credentials, even if not updated.
- func (p *Provider) AdjustExpiresBy(creds aws.Credentials, dur time.Duration) (
- aws.Credentials, error,
- ) {
- if !creds.CanExpire {
- return creds, nil
- }
- if creds.Expires.Before(sdk.NowTime().Add(15 * time.Minute)) {
- return creds, nil
- }
- creds.Expires = creds.Expires.Add(dur)
- return creds, nil
- }
// ec2RoleCredRespBody is the shape that credential request responses are
// unmarshaled into. It carries the credential fields together with the
// Code/Message pair that reports the request's status.
type ec2RoleCredRespBody struct {
	// Success state: the credentials and the time at which they expire.
	Expiration                          time.Time
	AccessKeyID, SecretAccessKey, Token string

	// Error state: status code and accompanying message.
	Code, Message string
}
// iamSecurityCredsPath is the IMDS metadata path used both to list the
// available IAM roles and, with a role name appended, to request that role's
// credentials.
const (
	iamSecurityCredsPath = "/iam/security-credentials/"
)
- // requestCredList requests a list of credentials from the EC2 service. If
- // there are no credentials, or there is an error making or receiving the
- // request
- func requestCredList(ctx context.Context, client GetMetadataAPIClient) ([]string, error) {
- resp, err := client.GetMetadata(ctx, &imds.GetMetadataInput{
- Path: iamSecurityCredsPath,
- })
- if err != nil {
- return nil, fmt.Errorf("no EC2 IMDS role found, %w", err)
- }
- defer resp.Content.Close()
- credsList := []string{}
- s := bufio.NewScanner(resp.Content)
- for s.Scan() {
- credsList = append(credsList, s.Text())
- }
- if err := s.Err(); err != nil {
- return nil, fmt.Errorf("failed to read EC2 IMDS role, %w", err)
- }
- return credsList, nil
- }
- // requestCred requests the credentials for a specific credentials from the EC2 service.
- //
- // If the credentials cannot be found, or there is an error reading the response
- // and error will be returned.
- func requestCred(ctx context.Context, client GetMetadataAPIClient, credsName string) (ec2RoleCredRespBody, error) {
- resp, err := client.GetMetadata(ctx, &imds.GetMetadataInput{
- Path: path.Join(iamSecurityCredsPath, credsName),
- })
- if err != nil {
- return ec2RoleCredRespBody{},
- fmt.Errorf("failed to get %s EC2 IMDS role credentials, %w",
- credsName, err)
- }
- defer resp.Content.Close()
- var respCreds ec2RoleCredRespBody
- if err := json.NewDecoder(resp.Content).Decode(&respCreds); err != nil {
- return ec2RoleCredRespBody{},
- fmt.Errorf("failed to decode %s EC2 IMDS role credentials, %w",
- credsName, err)
- }
- if !strings.EqualFold(respCreds.Code, "Success") {
- // If an error code was returned something failed requesting the role.
- return ec2RoleCredRespBody{},
- fmt.Errorf("failed to get %s EC2 IMDS role credentials, %w",
- credsName,
- &smithy.GenericAPIError{Code: respCreds.Code, Message: respCreds.Message})
- }
- return respCreds, nil
- }
|