123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595 |
- // Copyright 2021 The Go Authors. All rights reserved.
- // Use of this source code is governed by a BSD-style
- // license that can be found in the LICENSE file.
- package externalaccount
- import (
- "bytes"
- "context"
- "crypto/hmac"
- "crypto/sha256"
- "encoding/hex"
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "io/ioutil"
- "net/http"
- "net/url"
- "os"
- "path"
- "sort"
- "strings"
- "time"
- "golang.org/x/oauth2"
- )
// awsSecurityCredentials groups the AWS access key, secret key and session
// token. The struct tags name the JSON keys each field is decoded from.
type awsSecurityCredentials struct {
	// AccessKeyID is the AWS access key identifier.
	AccessKeyID string `json:"AccessKeyID"`
	// SecretAccessKey is the secret key paired with AccessKeyID.
	SecretAccessKey string `json:"SecretAccessKey"`
	// SecurityToken is the session token; it is decoded from the "Token" key.
	SecurityToken string `json:"Token"`
}
- // awsRequestSigner is a utility class to sign http requests using a AWS V4 signature.
- type awsRequestSigner struct {
- RegionName string
- AwsSecurityCredentials awsSecurityCredentials
- }
// getenv is the environment lookup used throughout this file. It defaults to
// os.Getenv and is a variable so tests can substitute their own lookup.
var getenv = os.Getenv
const (
	// awsAlgorithm identifies the AWS Signature Version 4 signing algorithm.
	awsAlgorithm = "AWS4-HMAC-SHA256"

	// awsRequestType is the termination string of the AWS credential scope
	// value, as defined in
	// https://docs.aws.amazon.com/general/latest/gr/sigv4-create-string-to-sign.html
	awsRequestType = "aws4_request"

	// awsSecurityTokenHeader is the AWS authorization header that carries the
	// security session token, when one is available.
	awsSecurityTokenHeader = "x-amz-security-token"

	// awsIMDSv2SessionTokenHeader is the header that carries the session
	// token on metadata endpoint calls.
	awsIMDSv2SessionTokenHeader = "X-aws-ec2-metadata-token"
	// awsIMDSv2SessionTtlHeader is sent when requesting a session token.
	awsIMDSv2SessionTtlHeader = "X-aws-ec2-metadata-token-ttl-seconds"
	// awsIMDSv2SessionTtl is the value sent in awsIMDSv2SessionTtlHeader.
	awsIMDSv2SessionTtl = "300"

	// awsDateHeader is the AWS authorization header for the auto-generated
	// date.
	awsDateHeader = "x-amz-date"

	// Names of the supported AWS configuration environment variables.
	awsAccessKeyId     = "AWS_ACCESS_KEY_ID"
	awsDefaultRegion   = "AWS_DEFAULT_REGION"
	awsRegion          = "AWS_REGION"
	awsSecretAccessKey = "AWS_SECRET_ACCESS_KEY"
	awsSessionToken    = "AWS_SESSION_TOKEN"

	// time.Format layouts for the full timestamp and the date-only stamp
	// used while signing.
	awsTimeFormatLong  = "20060102T150405Z"
	awsTimeFormatShort = "20060102"
)
// getSha256 returns the lowercase hex encoding of the SHA-256 digest of
// input. The returned error is always nil.
func getSha256(input []byte) (string, error) {
	digest := sha256.Sum256(input)
	return hex.EncodeToString(digest[:]), nil
}
// getHmacSha256 computes the HMAC-SHA256 of input under key and returns the
// raw (unencoded) MAC bytes.
func getHmacSha256(key, input []byte) ([]byte, error) {
	mac := hmac.New(sha256.New, key)
	_, err := mac.Write(input)
	if err != nil {
		return nil, err
	}
	return mac.Sum(nil), nil
}
// cloneRequest returns a shallow copy of r whose Header map and header value
// slices are freshly allocated, so header edits on the copy do not touch r.
// All other fields are copied by plain assignment.
func cloneRequest(r *http.Request) *http.Request {
	clone := new(http.Request)
	*clone = *r
	if r.Header == nil {
		return clone
	}

	// Count every header value so one allocation can back all of them.
	total := 0
	for _, values := range r.Header {
		total += len(values)
	}
	pool := make([]string, total)

	headers := make(http.Header, len(r.Header))
	for name, values := range r.Header {
		n := copy(pool, values)
		// Cap each slice at its own length so appending to one header cannot
		// spill into the region used by the next.
		headers[name] = pool[:n:n]
		pool = pool[n:]
	}
	clone.Header = headers
	return clone
}
// canonicalPath returns the request's escaped URL path after path.Clean, or
// "/" when the escaped path is empty.
func canonicalPath(req *http.Request) string {
	if escaped := req.URL.EscapedPath(); escaped != "" {
		return path.Clean(escaped)
	}
	return "/"
}
// canonicalQuery returns the request's query string re-encoded with the
// values of each key sorted; Encode emits the keys themselves in sorted order.
func canonicalQuery(req *http.Request) string {
	params := req.URL.Query()
	for _, values := range params {
		// Sorts in place: values shares its backing array with the map entry.
		sort.Strings(values)
	}
	return params.Encode()
}
// canonicalHeaders builds the two header strings used in the canonical
// request.
//
// The first result is the lowercased header names, sorted and joined with
// ";". The second is one "name:value\n" line per header in the same order,
// where multiple values of a header are joined with ",".
//
// Header names that differ only in case are merged under one lowercased
// name. The merged values are always copied into storage owned by this
// function, so the request's own header slices are never written to.
func canonicalHeaders(req *http.Request) (string, string) {
	// Header keys need to be sorted alphabetically.
	var headers []string
	lowerCaseHeaders := make(http.Header, len(req.Header))
	for k, v := range req.Header {
		lower := strings.ToLower(k)
		if _, seen := lowerCaseHeaders[lower]; !seen {
			headers = append(headers, lower)
		}
		// Appending onto the (initially nil) merged entry allocates fresh
		// storage instead of aliasing v, whose spare capacity may belong to
		// the caller.
		lowerCaseHeaders[lower] = append(lowerCaseHeaders[lower], v...)
	}
	sort.Strings(headers)

	var fullHeaders bytes.Buffer
	for _, header := range headers {
		fullHeaders.WriteString(header)
		fullHeaders.WriteRune(':')
		fullHeaders.WriteString(strings.Join(lowerCaseHeaders[header], ","))
		fullHeaders.WriteRune('\n')
	}
	return strings.Join(headers, ";"), fullHeaders.String()
}
- func requestDataHash(req *http.Request) (string, error) {
- var requestData []byte
- if req.Body != nil {
- requestBody, err := req.GetBody()
- if err != nil {
- return "", err
- }
- defer requestBody.Close()
- requestData, err = ioutil.ReadAll(io.LimitReader(requestBody, 1<<20))
- if err != nil {
- return "", err
- }
- }
- return getSha256(requestData)
- }
// requestHost returns req.Host when it is set, and otherwise the host taken
// from the request URL.
func requestHost(req *http.Request) string {
	host := req.Host
	if host == "" {
		host = req.URL.Host
	}
	return host
}
- func canonicalRequest(req *http.Request, canonicalHeaderColumns, canonicalHeaderData string) (string, error) {
- dataHash, err := requestDataHash(req)
- if err != nil {
- return "", err
- }
- return fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s", req.Method, canonicalPath(req), canonicalQuery(req), canonicalHeaderData, canonicalHeaderColumns, dataHash), nil
- }
- // SignRequest adds the appropriate headers to an http.Request
- // or returns an error if something prevented this.
- func (rs *awsRequestSigner) SignRequest(req *http.Request) error {
- signedRequest := cloneRequest(req)
- timestamp := now()
- signedRequest.Header.Add("host", requestHost(req))
- if rs.AwsSecurityCredentials.SecurityToken != "" {
- signedRequest.Header.Add(awsSecurityTokenHeader, rs.AwsSecurityCredentials.SecurityToken)
- }
- if signedRequest.Header.Get("date") == "" {
- signedRequest.Header.Add(awsDateHeader, timestamp.Format(awsTimeFormatLong))
- }
- authorizationCode, err := rs.generateAuthentication(signedRequest, timestamp)
- if err != nil {
- return err
- }
- signedRequest.Header.Set("Authorization", authorizationCode)
- req.Header = signedRequest.Header
- return nil
- }
- func (rs *awsRequestSigner) generateAuthentication(req *http.Request, timestamp time.Time) (string, error) {
- canonicalHeaderColumns, canonicalHeaderData := canonicalHeaders(req)
- dateStamp := timestamp.Format(awsTimeFormatShort)
- serviceName := ""
- if splitHost := strings.Split(requestHost(req), "."); len(splitHost) > 0 {
- serviceName = splitHost[0]
- }
- credentialScope := fmt.Sprintf("%s/%s/%s/%s", dateStamp, rs.RegionName, serviceName, awsRequestType)
- requestString, err := canonicalRequest(req, canonicalHeaderColumns, canonicalHeaderData)
- if err != nil {
- return "", err
- }
- requestHash, err := getSha256([]byte(requestString))
- if err != nil {
- return "", err
- }
- stringToSign := fmt.Sprintf("%s\n%s\n%s\n%s", awsAlgorithm, timestamp.Format(awsTimeFormatLong), credentialScope, requestHash)
- signingKey := []byte("AWS4" + rs.AwsSecurityCredentials.SecretAccessKey)
- for _, signingInput := range []string{
- dateStamp, rs.RegionName, serviceName, awsRequestType, stringToSign,
- } {
- signingKey, err = getHmacSha256(signingKey, []byte(signingInput))
- if err != nil {
- return "", err
- }
- }
- return fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s", awsAlgorithm, rs.AwsSecurityCredentials.AccessKeyID, credentialScope, canonicalHeaderColumns, hex.EncodeToString(signingKey)), nil
- }
- type awsCredentialSource struct {
- EnvironmentID string
- RegionURL string
- RegionalCredVerificationURL string
- CredVerificationURL string
- IMDSv2SessionTokenURL string
- TargetResource string
- requestSigner *awsRequestSigner
- region string
- ctx context.Context
- client *http.Client
- }
// awsRequestHeader is a single header name/value pair in the serialized
// signed request.
type awsRequestHeader struct {
	// Key is the header name.
	Key string `json:"key"`
	// Value is one value of that header.
	Value string `json:"value"`
}
- type awsRequest struct {
- URL string `json:"url"`
- Method string `json:"method"`
- Headers []awsRequestHeader `json:"headers"`
- }
- func (cs awsCredentialSource) validateMetadataServers() error {
- if err := cs.validateMetadataServer(cs.RegionURL, "region_url"); err != nil {
- return err
- }
- if err := cs.validateMetadataServer(cs.CredVerificationURL, "url"); err != nil {
- return err
- }
- return cs.validateMetadataServer(cs.IMDSv2SessionTokenURL, "imdsv2_session_token_url")
- }
// validHostnames is the allowlist of hostnames accepted for metadata server
// URLs.
var validHostnames = []string{"169.254.169.254", "fd00:ec2::254"}
- func (cs awsCredentialSource) isValidMetadataServer(metadataUrl string) bool {
- if metadataUrl == "" {
- // Zero value means use default, which is valid.
- return true
- }
- u, err := url.Parse(metadataUrl)
- if err != nil {
- // Unparseable URL means invalid
- return false
- }
- for _, validHostname := range validHostnames {
- if u.Hostname() == validHostname {
- // If it's one of the valid hostnames, everything is good
- return true
- }
- }
- // hostname not found in our allowlist, so not valid
- return false
- }
- func (cs awsCredentialSource) validateMetadataServer(metadataUrl, urlName string) error {
- if !cs.isValidMetadataServer(metadataUrl) {
- return fmt.Errorf("oauth2/google: invalid hostname %s for %s", metadataUrl, urlName)
- }
- return nil
- }
- func (cs awsCredentialSource) doRequest(req *http.Request) (*http.Response, error) {
- if cs.client == nil {
- cs.client = oauth2.NewClient(cs.ctx, nil)
- }
- return cs.client.Do(req.WithContext(cs.ctx))
- }
- func canRetrieveRegionFromEnvironment() bool {
- // The AWS region can be provided through AWS_REGION or AWS_DEFAULT_REGION. Only one is
- // required.
- return getenv(awsRegion) != "" || getenv(awsDefaultRegion) != ""
- }
- func canRetrieveSecurityCredentialFromEnvironment() bool {
- // Check if both AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY are available.
- return getenv(awsAccessKeyId) != "" && getenv(awsSecretAccessKey) != ""
- }
- func shouldUseMetadataServer() bool {
- return !canRetrieveRegionFromEnvironment() || !canRetrieveSecurityCredentialFromEnvironment()
- }
- func (cs awsCredentialSource) subjectToken() (string, error) {
- if cs.requestSigner == nil {
- headers := make(map[string]string)
- if shouldUseMetadataServer() {
- awsSessionToken, err := cs.getAWSSessionToken()
- if err != nil {
- return "", err
- }
- if awsSessionToken != "" {
- headers[awsIMDSv2SessionTokenHeader] = awsSessionToken
- }
- }
- awsSecurityCredentials, err := cs.getSecurityCredentials(headers)
- if err != nil {
- return "", err
- }
- if cs.region, err = cs.getRegion(headers); err != nil {
- return "", err
- }
- cs.requestSigner = &awsRequestSigner{
- RegionName: cs.region,
- AwsSecurityCredentials: awsSecurityCredentials,
- }
- }
- // Generate the signed request to AWS STS GetCallerIdentity API.
- // Use the required regional endpoint. Otherwise, the request will fail.
- req, err := http.NewRequest("POST", strings.Replace(cs.RegionalCredVerificationURL, "{region}", cs.region, 1), nil)
- if err != nil {
- return "", err
- }
- // The full, canonical resource name of the workload identity pool
- // provider, with or without the HTTPS prefix.
- // Including this header as part of the signature is recommended to
- // ensure data integrity.
- if cs.TargetResource != "" {
- req.Header.Add("x-goog-cloud-target-resource", cs.TargetResource)
- }
- cs.requestSigner.SignRequest(req)
- /*
- The GCP STS endpoint expects the headers to be formatted as:
- # [
- # {key: 'x-amz-date', value: '...'},
- # {key: 'Authorization', value: '...'},
- # ...
- # ]
- # And then serialized as:
- # quote(json.dumps({
- # url: '...',
- # method: 'POST',
- # headers: [{key: 'x-amz-date', value: '...'}, ...]
- # }))
- */
- awsSignedReq := awsRequest{
- URL: req.URL.String(),
- Method: "POST",
- }
- for headerKey, headerList := range req.Header {
- for _, headerValue := range headerList {
- awsSignedReq.Headers = append(awsSignedReq.Headers, awsRequestHeader{
- Key: headerKey,
- Value: headerValue,
- })
- }
- }
- sort.Slice(awsSignedReq.Headers, func(i, j int) bool {
- headerCompare := strings.Compare(awsSignedReq.Headers[i].Key, awsSignedReq.Headers[j].Key)
- if headerCompare == 0 {
- return strings.Compare(awsSignedReq.Headers[i].Value, awsSignedReq.Headers[j].Value) < 0
- }
- return headerCompare < 0
- })
- result, err := json.Marshal(awsSignedReq)
- if err != nil {
- return "", err
- }
- return url.QueryEscape(string(result)), nil
- }
- func (cs *awsCredentialSource) getAWSSessionToken() (string, error) {
- if cs.IMDSv2SessionTokenURL == "" {
- return "", nil
- }
- req, err := http.NewRequest("PUT", cs.IMDSv2SessionTokenURL, nil)
- if err != nil {
- return "", err
- }
- req.Header.Add(awsIMDSv2SessionTtlHeader, awsIMDSv2SessionTtl)
- resp, err := cs.doRequest(req)
- if err != nil {
- return "", err
- }
- defer resp.Body.Close()
- respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
- if err != nil {
- return "", err
- }
- if resp.StatusCode != 200 {
- return "", fmt.Errorf("oauth2/google: unable to retrieve AWS session token - %s", string(respBody))
- }
- return string(respBody), nil
- }
- func (cs *awsCredentialSource) getRegion(headers map[string]string) (string, error) {
- if canRetrieveRegionFromEnvironment() {
- if envAwsRegion := getenv(awsRegion); envAwsRegion != "" {
- return envAwsRegion, nil
- }
- return getenv("AWS_DEFAULT_REGION"), nil
- }
- if cs.RegionURL == "" {
- return "", errors.New("oauth2/google: unable to determine AWS region")
- }
- req, err := http.NewRequest("GET", cs.RegionURL, nil)
- if err != nil {
- return "", err
- }
- for name, value := range headers {
- req.Header.Add(name, value)
- }
- resp, err := cs.doRequest(req)
- if err != nil {
- return "", err
- }
- defer resp.Body.Close()
- respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
- if err != nil {
- return "", err
- }
- if resp.StatusCode != 200 {
- return "", fmt.Errorf("oauth2/google: unable to retrieve AWS region - %s", string(respBody))
- }
- // This endpoint will return the region in format: us-east-2b.
- // Only the us-east-2 part should be used.
- respBodyEnd := 0
- if len(respBody) > 1 {
- respBodyEnd = len(respBody) - 1
- }
- return string(respBody[:respBodyEnd]), nil
- }
- func (cs *awsCredentialSource) getSecurityCredentials(headers map[string]string) (result awsSecurityCredentials, err error) {
- if canRetrieveSecurityCredentialFromEnvironment() {
- return awsSecurityCredentials{
- AccessKeyID: getenv(awsAccessKeyId),
- SecretAccessKey: getenv(awsSecretAccessKey),
- SecurityToken: getenv(awsSessionToken),
- }, nil
- }
- roleName, err := cs.getMetadataRoleName(headers)
- if err != nil {
- return
- }
- credentials, err := cs.getMetadataSecurityCredentials(roleName, headers)
- if err != nil {
- return
- }
- if credentials.AccessKeyID == "" {
- return result, errors.New("oauth2/google: missing AccessKeyId credential")
- }
- if credentials.SecretAccessKey == "" {
- return result, errors.New("oauth2/google: missing SecretAccessKey credential")
- }
- return credentials, nil
- }
- func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string, headers map[string]string) (awsSecurityCredentials, error) {
- var result awsSecurityCredentials
- req, err := http.NewRequest("GET", fmt.Sprintf("%s/%s", cs.CredVerificationURL, roleName), nil)
- if err != nil {
- return result, err
- }
- req.Header.Add("Content-Type", "application/json")
- for name, value := range headers {
- req.Header.Add(name, value)
- }
- resp, err := cs.doRequest(req)
- if err != nil {
- return result, err
- }
- defer resp.Body.Close()
- respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
- if err != nil {
- return result, err
- }
- if resp.StatusCode != 200 {
- return result, fmt.Errorf("oauth2/google: unable to retrieve AWS security credentials - %s", string(respBody))
- }
- err = json.Unmarshal(respBody, &result)
- return result, err
- }
- func (cs *awsCredentialSource) getMetadataRoleName(headers map[string]string) (string, error) {
- if cs.CredVerificationURL == "" {
- return "", errors.New("oauth2/google: unable to determine the AWS metadata server security credentials endpoint")
- }
- req, err := http.NewRequest("GET", cs.CredVerificationURL, nil)
- if err != nil {
- return "", err
- }
- for name, value := range headers {
- req.Header.Add(name, value)
- }
- resp, err := cs.doRequest(req)
- if err != nil {
- return "", err
- }
- defer resp.Body.Close()
- respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
- if err != nil {
- return "", err
- }
- if resp.StatusCode != 200 {
- return "", fmt.Errorf("oauth2/google: unable to retrieve AWS role name - %s", string(respBody))
- }
- return string(respBody), nil
- }
|