Browse Source

Merge pull request #35055 from adnxn/creds-endpoint

Add credentials endpoint option for awslogs driver
Michael Crosby 7 năm trước cách đây
mục cha
commit
158c072bde

+ 42 - 11
daemon/logger/awslogs/cloudwatchlogs.go

@@ -14,6 +14,7 @@ import (
 
 	"github.com/aws/aws-sdk-go/aws"
 	"github.com/aws/aws-sdk-go/aws/awserr"
+	"github.com/aws/aws-sdk-go/aws/credentials/endpointcreds"
 	"github.com/aws/aws-sdk-go/aws/ec2metadata"
 	"github.com/aws/aws-sdk-go/aws/request"
 	"github.com/aws/aws-sdk-go/aws/session"
@@ -26,16 +27,17 @@ import (
 )
 
 const (
-	name                  = "awslogs"
-	regionKey             = "awslogs-region"
-	regionEnvKey          = "AWS_REGION"
-	logGroupKey           = "awslogs-group"
-	logStreamKey          = "awslogs-stream"
-	logCreateGroupKey     = "awslogs-create-group"
-	tagKey                = "tag"
-	datetimeFormatKey     = "awslogs-datetime-format"
-	multilinePatternKey   = "awslogs-multiline-pattern"
-	batchPublishFrequency = 5 * time.Second
+	name                   = "awslogs"
+	regionKey              = "awslogs-region"
+	regionEnvKey           = "AWS_REGION"
+	logGroupKey            = "awslogs-group"
+	logStreamKey           = "awslogs-stream"
+	logCreateGroupKey      = "awslogs-create-group"
+	tagKey                 = "tag"
+	datetimeFormatKey      = "awslogs-datetime-format"
+	multilinePatternKey    = "awslogs-multiline-pattern"
+	credentialsEndpointKey = "awslogs-credentials-endpoint"
+	batchPublishFrequency  = 5 * time.Second
 
 	// See: http://docs.aws.amazon.com/AmazonCloudWatchLogs/latest/APIReference/API_PutLogEvents.html
 	perEventBytes          = 26
@@ -50,6 +52,8 @@ const (
 	invalidSequenceTokenCode  = "InvalidSequenceTokenException"
 	resourceNotFoundCode      = "ResourceNotFoundException"
 
+	credentialsEndpoint = "http://169.254.170.2"
+
 	userAgentHeader = "User-Agent"
 )
 
@@ -198,6 +202,10 @@ var newRegionFinder = func() regionFinder {
 	return ec2metadata.New(session.New())
 }
 
+// newSDKEndpoint is a variable such that the implementation
+// can be swapped out for unit tests.
+var newSDKEndpoint = credentialsEndpoint
+
 // newAWSLogsClient creates the service client for Amazon CloudWatch Logs.
 // Customizations to the default client from the SDK include a Docker-specific
 // User-Agent string and automatic region detection using the EC2 Instance
@@ -222,11 +230,33 @@ func newAWSLogsClient(info logger.Info) (api, error) {
 		}
 		region = &r
 	}
+
+	sess, err := session.NewSession()
+	if err != nil {
+		return nil, errors.New("Failed to create a service client session for for awslogs driver")
+	}
+
+	// attach region to cloudwatchlogs config
+	sess.Config.Region = region
+
+	if uri, ok := info.Config[credentialsEndpointKey]; ok {
+		logrus.Debugf("Trying to get credentials from awslogs-credentials-endpoint")
+
+		endpoint := fmt.Sprintf("%s%s", newSDKEndpoint, uri)
+		creds := endpointcreds.NewCredentialsClient(*sess.Config, sess.Handlers, endpoint,
+			func(p *endpointcreds.Provider) {
+				p.ExpiryWindow = 5 * time.Minute
+			})
+
+		// attach credentials to cloudwatchlogs config
+		sess.Config.Credentials = creds
+	}
+
 	logrus.WithFields(logrus.Fields{
 		"region": *region,
 	}).Debug("Created awslogs client")
 
-	client := cloudwatchlogs.New(session.New(), aws.NewConfig().WithRegion(*region))
+	client := cloudwatchlogs.New(sess)
 
 	client.Handlers.Build.PushBackNamed(request.NamedHandler{
 		Name: "DockerUserAgentHandler",
@@ -525,6 +555,7 @@ func ValidateLogOpt(cfg map[string]string) error {
 		case tagKey:
 		case datetimeFormatKey:
 		case multilinePatternKey:
+		case credentialsEndpointKey:
 		default:
 			return fmt.Errorf("unknown log opt '%s' for %s log driver", key, name)
 		}

+ 121 - 0
daemon/logger/awslogs/cloudwatchlogs_test.go

@@ -3,7 +3,10 @@ package awslogs
 import (
 	"errors"
 	"fmt"
+	"io/ioutil"
 	"net/http"
+	"net/http/httptest"
+	"os"
 	"reflect"
 	"regexp"
 	"runtime"
@@ -1065,3 +1068,121 @@ func BenchmarkUnwrapEvents(b *testing.B) {
 		as.Len(res, maximumLogEventsPerPut)
 	}
 }
+
+func TestNewAWSLogsClientCredentialEndpointDetect(t *testing.T) {
+	// required for the cloudwatchlogs client
+	os.Setenv("AWS_REGION", "us-west-2")
+	defer os.Unsetenv("AWS_REGION")
+
+	credsResp := `{
+		"AccessKeyId" :    "test-access-key-id",
+		"SecretAccessKey": "test-secret-access-key"
+		}`
+
+	expectedAccessKeyID := "test-access-key-id"
+	expectedSecretAccessKey := "test-secret-access-key"
+
+	testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("Content-Type", "application/json")
+		fmt.Fprintln(w, credsResp)
+	}))
+	defer testServer.Close()
+
+	// set the SDKEndpoint in the driver
+	newSDKEndpoint = testServer.URL
+
+	info := logger.Info{
+		Config: map[string]string{},
+	}
+
+	info.Config["awslogs-credentials-endpoint"] = "/creds"
+
+	c, err := newAWSLogsClient(info)
+	assert.NoError(t, err)
+
+	client := c.(*cloudwatchlogs.CloudWatchLogs)
+
+	creds, err := client.Config.Credentials.Get()
+	assert.NoError(t, err)
+
+	assert.Equal(t, expectedAccessKeyID, creds.AccessKeyID)
+	assert.Equal(t, expectedSecretAccessKey, creds.SecretAccessKey)
+}
+
+func TestNewAWSLogsClientCredentialEnvironmentVariable(t *testing.T) {
+	// required for the cloudwatchlogs client
+	os.Setenv("AWS_REGION", "us-west-2")
+	defer os.Unsetenv("AWS_REGION")
+
+	expectedAccessKeyID := "test-access-key-id"
+	expectedSecretAccessKey := "test-secret-access-key"
+
+	os.Setenv("AWS_ACCESS_KEY_ID", expectedAccessKeyID)
+	defer os.Unsetenv("AWS_ACCESS_KEY_ID")
+
+	os.Setenv("AWS_SECRET_ACCESS_KEY", expectedSecretAccessKey)
+	defer os.Unsetenv("AWS_SECRET_ACCESS_KEY")
+
+	info := logger.Info{
+		Config: map[string]string{},
+	}
+
+	c, err := newAWSLogsClient(info)
+	assert.NoError(t, err)
+
+	client := c.(*cloudwatchlogs.CloudWatchLogs)
+
+	creds, err := client.Config.Credentials.Get()
+	assert.NoError(t, err)
+
+	assert.Equal(t, expectedAccessKeyID, creds.AccessKeyID)
+	assert.Equal(t, expectedSecretAccessKey, creds.SecretAccessKey)
+
+}
+
+func TestNewAWSLogsClientCredentialSharedFile(t *testing.T) {
+	// required for the cloudwatchlogs client
+	os.Setenv("AWS_REGION", "us-west-2")
+	defer os.Unsetenv("AWS_REGION")
+
+	expectedAccessKeyID := "test-access-key-id"
+	expectedSecretAccessKey := "test-secret-access-key"
+
+	contentStr := `
+	[default]
+	aws_access_key_id = "test-access-key-id"
+	aws_secret_access_key =  "test-secret-access-key"
+	`
+	content := []byte(contentStr)
+
+	tmpfile, err := ioutil.TempFile("", "example")
+	defer os.Remove(tmpfile.Name()) // clean up
+	assert.NoError(t, err)
+
+	_, err = tmpfile.Write(content)
+	assert.NoError(t, err)
+
+	err = tmpfile.Close()
+	assert.NoError(t, err)
+
+	os.Unsetenv("AWS_ACCESS_KEY_ID")
+	os.Unsetenv("AWS_SECRET_ACCESS_KEY")
+
+	os.Setenv("AWS_SHARED_CREDENTIALS_FILE", tmpfile.Name())
+	defer os.Unsetenv("AWS_SHARED_CREDENTIALS_FILE")
+
+	info := logger.Info{
+		Config: map[string]string{},
+	}
+
+	c, err := newAWSLogsClient(info)
+	assert.NoError(t, err)
+
+	client := c.(*cloudwatchlogs.CloudWatchLogs)
+
+	creds, err := client.Config.Credentials.Get()
+	assert.NoError(t, err)
+
+	assert.Equal(t, expectedAccessKeyID, creds.AccessKeyID)
+	assert.Equal(t, expectedSecretAccessKey, creds.SecretAccessKey)
+}