aws.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595
  1. // Copyright 2021 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package externalaccount
  5. import (
  6. "bytes"
  7. "context"
  8. "crypto/hmac"
  9. "crypto/sha256"
  10. "encoding/hex"
  11. "encoding/json"
  12. "errors"
  13. "fmt"
  14. "io"
  15. "io/ioutil"
  16. "net/http"
  17. "net/url"
  18. "os"
  19. "path"
  20. "sort"
  21. "strings"
  22. "time"
  23. "golang.org/x/oauth2"
  24. )
// awsSecurityCredentials holds the AWS credentials used to sign requests.
//
// It is filled either from the AWS_* environment variables or by decoding the
// metadata server's security credentials response (getMetadataSecurityCredentials).
// encoding/json matches keys case-insensitively, so a key such as
// "AccessKeyId" also populates AccessKeyID. The session token is read from
// the "Token" key.
type awsSecurityCredentials struct {
	AccessKeyID     string `json:"AccessKeyID"`
	SecretAccessKey string `json:"SecretAccessKey"`
	SecurityToken   string `json:"Token"`
}
// awsRequestSigner signs HTTP requests using AWS Signature Version 4.
type awsRequestSigner struct {
	// RegionName is the AWS region used in the credential scope and in
	// signing-key derivation.
	RegionName string
	// AwsSecurityCredentials supplies the access key ID, the secret access key
	// and the optional session token used when signing.
	AwsSecurityCredentials awsSecurityCredentials
}
// getenv reads an environment variable. It is a package variable aliasing
// os.Getenv so that tests can substitute a fake environment.
var getenv = os.Getenv
const (
	// AWS Signature Version 4 signing algorithm identifier.
	awsAlgorithm = "AWS4-HMAC-SHA256"
	// The termination string for the AWS credential scope value as defined in
	// https://docs.aws.amazon.com/general/latest/gr/sigv4-create-string-to-sign.html
	awsRequestType = "aws4_request"
	// The AWS authorization header name for the security session token, added
	// only when a session token is available.
	awsSecurityTokenHeader = "x-amz-security-token"
	// The name of the header carrying the IMDSv2 session token on metadata
	// endpoint calls.
	awsIMDSv2SessionTokenHeader = "X-aws-ec2-metadata-token"
	// Header and value requesting the lifetime, in seconds (per the header
	// name), of a new IMDSv2 session token; sent with the token PUT request.
	awsIMDSv2SessionTtlHeader = "X-aws-ec2-metadata-token-ttl-seconds"
	awsIMDSv2SessionTtl       = "300"
	// The header that carries the signing timestamp when the request has no
	// Date header.
	awsDateHeader = "x-amz-date"
	// Supported AWS configuration environment variables.
	awsAccessKeyId     = "AWS_ACCESS_KEY_ID"
	awsDefaultRegion   = "AWS_DEFAULT_REGION"
	awsRegion          = "AWS_REGION"
	awsSecretAccessKey = "AWS_SECRET_ACCESS_KEY"
	awsSessionToken    = "AWS_SESSION_TOKEN"
	// Go time layouts for the full signing timestamp (x-amz-date header and
	// string to sign) and for the date stamp used in the credential scope.
	awsTimeFormatLong  = "20060102T150405Z"
	awsTimeFormatShort = "20060102"
)
  60. func getSha256(input []byte) (string, error) {
  61. hash := sha256.New()
  62. if _, err := hash.Write(input); err != nil {
  63. return "", err
  64. }
  65. return hex.EncodeToString(hash.Sum(nil)), nil
  66. }
  67. func getHmacSha256(key, input []byte) ([]byte, error) {
  68. hash := hmac.New(sha256.New, key)
  69. if _, err := hash.Write(input); err != nil {
  70. return nil, err
  71. }
  72. return hash.Sum(nil), nil
  73. }
  74. func cloneRequest(r *http.Request) *http.Request {
  75. r2 := new(http.Request)
  76. *r2 = *r
  77. if r.Header != nil {
  78. r2.Header = make(http.Header, len(r.Header))
  79. // Find total number of values.
  80. headerCount := 0
  81. for _, headerValues := range r.Header {
  82. headerCount += len(headerValues)
  83. }
  84. copiedHeaders := make([]string, headerCount) // shared backing array for headers' values
  85. for headerKey, headerValues := range r.Header {
  86. headerCount = copy(copiedHeaders, headerValues)
  87. r2.Header[headerKey] = copiedHeaders[:headerCount:headerCount]
  88. copiedHeaders = copiedHeaders[headerCount:]
  89. }
  90. }
  91. return r2
  92. }
  93. func canonicalPath(req *http.Request) string {
  94. result := req.URL.EscapedPath()
  95. if result == "" {
  96. return "/"
  97. }
  98. return path.Clean(result)
  99. }
  100. func canonicalQuery(req *http.Request) string {
  101. queryValues := req.URL.Query()
  102. for queryKey := range queryValues {
  103. sort.Strings(queryValues[queryKey])
  104. }
  105. return queryValues.Encode()
  106. }
  107. func canonicalHeaders(req *http.Request) (string, string) {
  108. // Header keys need to be sorted alphabetically.
  109. var headers []string
  110. lowerCaseHeaders := make(http.Header)
  111. for k, v := range req.Header {
  112. k := strings.ToLower(k)
  113. if _, ok := lowerCaseHeaders[k]; ok {
  114. // include additional values
  115. lowerCaseHeaders[k] = append(lowerCaseHeaders[k], v...)
  116. } else {
  117. headers = append(headers, k)
  118. lowerCaseHeaders[k] = v
  119. }
  120. }
  121. sort.Strings(headers)
  122. var fullHeaders bytes.Buffer
  123. for _, header := range headers {
  124. headerValue := strings.Join(lowerCaseHeaders[header], ",")
  125. fullHeaders.WriteString(header)
  126. fullHeaders.WriteRune(':')
  127. fullHeaders.WriteString(headerValue)
  128. fullHeaders.WriteRune('\n')
  129. }
  130. return strings.Join(headers, ";"), fullHeaders.String()
  131. }
  132. func requestDataHash(req *http.Request) (string, error) {
  133. var requestData []byte
  134. if req.Body != nil {
  135. requestBody, err := req.GetBody()
  136. if err != nil {
  137. return "", err
  138. }
  139. defer requestBody.Close()
  140. requestData, err = ioutil.ReadAll(io.LimitReader(requestBody, 1<<20))
  141. if err != nil {
  142. return "", err
  143. }
  144. }
  145. return getSha256(requestData)
  146. }
  147. func requestHost(req *http.Request) string {
  148. if req.Host != "" {
  149. return req.Host
  150. }
  151. return req.URL.Host
  152. }
  153. func canonicalRequest(req *http.Request, canonicalHeaderColumns, canonicalHeaderData string) (string, error) {
  154. dataHash, err := requestDataHash(req)
  155. if err != nil {
  156. return "", err
  157. }
  158. return fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s", req.Method, canonicalPath(req), canonicalQuery(req), canonicalHeaderData, canonicalHeaderColumns, dataHash), nil
  159. }
  160. // SignRequest adds the appropriate headers to an http.Request
  161. // or returns an error if something prevented this.
  162. func (rs *awsRequestSigner) SignRequest(req *http.Request) error {
  163. signedRequest := cloneRequest(req)
  164. timestamp := now()
  165. signedRequest.Header.Add("host", requestHost(req))
  166. if rs.AwsSecurityCredentials.SecurityToken != "" {
  167. signedRequest.Header.Add(awsSecurityTokenHeader, rs.AwsSecurityCredentials.SecurityToken)
  168. }
  169. if signedRequest.Header.Get("date") == "" {
  170. signedRequest.Header.Add(awsDateHeader, timestamp.Format(awsTimeFormatLong))
  171. }
  172. authorizationCode, err := rs.generateAuthentication(signedRequest, timestamp)
  173. if err != nil {
  174. return err
  175. }
  176. signedRequest.Header.Set("Authorization", authorizationCode)
  177. req.Header = signedRequest.Header
  178. return nil
  179. }
  180. func (rs *awsRequestSigner) generateAuthentication(req *http.Request, timestamp time.Time) (string, error) {
  181. canonicalHeaderColumns, canonicalHeaderData := canonicalHeaders(req)
  182. dateStamp := timestamp.Format(awsTimeFormatShort)
  183. serviceName := ""
  184. if splitHost := strings.Split(requestHost(req), "."); len(splitHost) > 0 {
  185. serviceName = splitHost[0]
  186. }
  187. credentialScope := fmt.Sprintf("%s/%s/%s/%s", dateStamp, rs.RegionName, serviceName, awsRequestType)
  188. requestString, err := canonicalRequest(req, canonicalHeaderColumns, canonicalHeaderData)
  189. if err != nil {
  190. return "", err
  191. }
  192. requestHash, err := getSha256([]byte(requestString))
  193. if err != nil {
  194. return "", err
  195. }
  196. stringToSign := fmt.Sprintf("%s\n%s\n%s\n%s", awsAlgorithm, timestamp.Format(awsTimeFormatLong), credentialScope, requestHash)
  197. signingKey := []byte("AWS4" + rs.AwsSecurityCredentials.SecretAccessKey)
  198. for _, signingInput := range []string{
  199. dateStamp, rs.RegionName, serviceName, awsRequestType, stringToSign,
  200. } {
  201. signingKey, err = getHmacSha256(signingKey, []byte(signingInput))
  202. if err != nil {
  203. return "", err
  204. }
  205. }
  206. return fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s", awsAlgorithm, rs.AwsSecurityCredentials.AccessKeyID, credentialScope, canonicalHeaderColumns, hex.EncodeToString(signingKey)), nil
  207. }
// awsCredentialSource produces a subject token from AWS credentials: a signed
// AWS STS GetCallerIdentity request (see subjectToken).
type awsCredentialSource struct {
	// EnvironmentID is not read by the code in this file.
	// NOTE(review): presumably identifies the AWS credential-source version
	// from the configuration — confirm against the config parser.
	EnvironmentID string
	// RegionURL is the metadata endpoint that getRegion queries when the
	// region is not available from the environment.
	RegionURL string
	// RegionalCredVerificationURL is the STS GetCallerIdentity URL template.
	// Its first "{region}" placeholder is replaced with the resolved region.
	RegionalCredVerificationURL string
	// CredVerificationURL is the metadata endpoint that returns the role name.
	// The role's credentials are then fetched from CredVerificationURL + "/" + role.
	CredVerificationURL string
	// IMDSv2SessionTokenURL, when set, receives a PUT that obtains an IMDSv2
	// session token. The token is then added to the metadata requests made by
	// subjectToken.
	IMDSv2SessionTokenURL string
	// TargetResource, when set, is sent and signed as the
	// x-goog-cloud-target-resource header.
	TargetResource string

	// requestSigner is built by subjectToken when it is nil.
	requestSigner *awsRequestSigner
	// region is the resolved AWS region substituted into RegionalCredVerificationURL.
	region string
	// ctx is attached to every request sent through doRequest.
	ctx context.Context
	// client sends metadata requests; doRequest creates one when it is nil.
	client *http.Client
}
// awsRequestHeader is one header of the serialized signed request that
// subjectToken produces.
type awsRequestHeader struct {
	Key   string `json:"key"`
	Value string `json:"value"`
}

// awsRequest is the JSON form of the signed GetCallerIdentity request.
// subjectToken marshals it and URL-query-escapes the result to build the
// subject token.
type awsRequest struct {
	URL     string             `json:"url"`
	Method  string             `json:"method"`
	Headers []awsRequestHeader `json:"headers"`
}
  229. func (cs awsCredentialSource) validateMetadataServers() error {
  230. if err := cs.validateMetadataServer(cs.RegionURL, "region_url"); err != nil {
  231. return err
  232. }
  233. if err := cs.validateMetadataServer(cs.CredVerificationURL, "url"); err != nil {
  234. return err
  235. }
  236. return cs.validateMetadataServer(cs.IMDSv2SessionTokenURL, "imdsv2_session_token_url")
  237. }
  238. var validHostnames []string = []string{"169.254.169.254", "fd00:ec2::254"}
  239. func (cs awsCredentialSource) isValidMetadataServer(metadataUrl string) bool {
  240. if metadataUrl == "" {
  241. // Zero value means use default, which is valid.
  242. return true
  243. }
  244. u, err := url.Parse(metadataUrl)
  245. if err != nil {
  246. // Unparseable URL means invalid
  247. return false
  248. }
  249. for _, validHostname := range validHostnames {
  250. if u.Hostname() == validHostname {
  251. // If it's one of the valid hostnames, everything is good
  252. return true
  253. }
  254. }
  255. // hostname not found in our allowlist, so not valid
  256. return false
  257. }
  258. func (cs awsCredentialSource) validateMetadataServer(metadataUrl, urlName string) error {
  259. if !cs.isValidMetadataServer(metadataUrl) {
  260. return fmt.Errorf("oauth2/google: invalid hostname %s for %s", metadataUrl, urlName)
  261. }
  262. return nil
  263. }
  264. func (cs awsCredentialSource) doRequest(req *http.Request) (*http.Response, error) {
  265. if cs.client == nil {
  266. cs.client = oauth2.NewClient(cs.ctx, nil)
  267. }
  268. return cs.client.Do(req.WithContext(cs.ctx))
  269. }
  270. func canRetrieveRegionFromEnvironment() bool {
  271. // The AWS region can be provided through AWS_REGION or AWS_DEFAULT_REGION. Only one is
  272. // required.
  273. return getenv(awsRegion) != "" || getenv(awsDefaultRegion) != ""
  274. }
  275. func canRetrieveSecurityCredentialFromEnvironment() bool {
  276. // Check if both AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY are available.
  277. return getenv(awsAccessKeyId) != "" && getenv(awsSecretAccessKey) != ""
  278. }
  279. func shouldUseMetadataServer() bool {
  280. return !canRetrieveRegionFromEnvironment() || !canRetrieveSecurityCredentialFromEnvironment()
  281. }
  282. func (cs awsCredentialSource) subjectToken() (string, error) {
  283. if cs.requestSigner == nil {
  284. headers := make(map[string]string)
  285. if shouldUseMetadataServer() {
  286. awsSessionToken, err := cs.getAWSSessionToken()
  287. if err != nil {
  288. return "", err
  289. }
  290. if awsSessionToken != "" {
  291. headers[awsIMDSv2SessionTokenHeader] = awsSessionToken
  292. }
  293. }
  294. awsSecurityCredentials, err := cs.getSecurityCredentials(headers)
  295. if err != nil {
  296. return "", err
  297. }
  298. if cs.region, err = cs.getRegion(headers); err != nil {
  299. return "", err
  300. }
  301. cs.requestSigner = &awsRequestSigner{
  302. RegionName: cs.region,
  303. AwsSecurityCredentials: awsSecurityCredentials,
  304. }
  305. }
  306. // Generate the signed request to AWS STS GetCallerIdentity API.
  307. // Use the required regional endpoint. Otherwise, the request will fail.
  308. req, err := http.NewRequest("POST", strings.Replace(cs.RegionalCredVerificationURL, "{region}", cs.region, 1), nil)
  309. if err != nil {
  310. return "", err
  311. }
  312. // The full, canonical resource name of the workload identity pool
  313. // provider, with or without the HTTPS prefix.
  314. // Including this header as part of the signature is recommended to
  315. // ensure data integrity.
  316. if cs.TargetResource != "" {
  317. req.Header.Add("x-goog-cloud-target-resource", cs.TargetResource)
  318. }
  319. cs.requestSigner.SignRequest(req)
  320. /*
  321. The GCP STS endpoint expects the headers to be formatted as:
  322. # [
  323. # {key: 'x-amz-date', value: '...'},
  324. # {key: 'Authorization', value: '...'},
  325. # ...
  326. # ]
  327. # And then serialized as:
  328. # quote(json.dumps({
  329. # url: '...',
  330. # method: 'POST',
  331. # headers: [{key: 'x-amz-date', value: '...'}, ...]
  332. # }))
  333. */
  334. awsSignedReq := awsRequest{
  335. URL: req.URL.String(),
  336. Method: "POST",
  337. }
  338. for headerKey, headerList := range req.Header {
  339. for _, headerValue := range headerList {
  340. awsSignedReq.Headers = append(awsSignedReq.Headers, awsRequestHeader{
  341. Key: headerKey,
  342. Value: headerValue,
  343. })
  344. }
  345. }
  346. sort.Slice(awsSignedReq.Headers, func(i, j int) bool {
  347. headerCompare := strings.Compare(awsSignedReq.Headers[i].Key, awsSignedReq.Headers[j].Key)
  348. if headerCompare == 0 {
  349. return strings.Compare(awsSignedReq.Headers[i].Value, awsSignedReq.Headers[j].Value) < 0
  350. }
  351. return headerCompare < 0
  352. })
  353. result, err := json.Marshal(awsSignedReq)
  354. if err != nil {
  355. return "", err
  356. }
  357. return url.QueryEscape(string(result)), nil
  358. }
  359. func (cs *awsCredentialSource) getAWSSessionToken() (string, error) {
  360. if cs.IMDSv2SessionTokenURL == "" {
  361. return "", nil
  362. }
  363. req, err := http.NewRequest("PUT", cs.IMDSv2SessionTokenURL, nil)
  364. if err != nil {
  365. return "", err
  366. }
  367. req.Header.Add(awsIMDSv2SessionTtlHeader, awsIMDSv2SessionTtl)
  368. resp, err := cs.doRequest(req)
  369. if err != nil {
  370. return "", err
  371. }
  372. defer resp.Body.Close()
  373. respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
  374. if err != nil {
  375. return "", err
  376. }
  377. if resp.StatusCode != 200 {
  378. return "", fmt.Errorf("oauth2/google: unable to retrieve AWS session token - %s", string(respBody))
  379. }
  380. return string(respBody), nil
  381. }
  382. func (cs *awsCredentialSource) getRegion(headers map[string]string) (string, error) {
  383. if canRetrieveRegionFromEnvironment() {
  384. if envAwsRegion := getenv(awsRegion); envAwsRegion != "" {
  385. return envAwsRegion, nil
  386. }
  387. return getenv("AWS_DEFAULT_REGION"), nil
  388. }
  389. if cs.RegionURL == "" {
  390. return "", errors.New("oauth2/google: unable to determine AWS region")
  391. }
  392. req, err := http.NewRequest("GET", cs.RegionURL, nil)
  393. if err != nil {
  394. return "", err
  395. }
  396. for name, value := range headers {
  397. req.Header.Add(name, value)
  398. }
  399. resp, err := cs.doRequest(req)
  400. if err != nil {
  401. return "", err
  402. }
  403. defer resp.Body.Close()
  404. respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
  405. if err != nil {
  406. return "", err
  407. }
  408. if resp.StatusCode != 200 {
  409. return "", fmt.Errorf("oauth2/google: unable to retrieve AWS region - %s", string(respBody))
  410. }
  411. // This endpoint will return the region in format: us-east-2b.
  412. // Only the us-east-2 part should be used.
  413. respBodyEnd := 0
  414. if len(respBody) > 1 {
  415. respBodyEnd = len(respBody) - 1
  416. }
  417. return string(respBody[:respBodyEnd]), nil
  418. }
  419. func (cs *awsCredentialSource) getSecurityCredentials(headers map[string]string) (result awsSecurityCredentials, err error) {
  420. if canRetrieveSecurityCredentialFromEnvironment() {
  421. return awsSecurityCredentials{
  422. AccessKeyID: getenv(awsAccessKeyId),
  423. SecretAccessKey: getenv(awsSecretAccessKey),
  424. SecurityToken: getenv(awsSessionToken),
  425. }, nil
  426. }
  427. roleName, err := cs.getMetadataRoleName(headers)
  428. if err != nil {
  429. return
  430. }
  431. credentials, err := cs.getMetadataSecurityCredentials(roleName, headers)
  432. if err != nil {
  433. return
  434. }
  435. if credentials.AccessKeyID == "" {
  436. return result, errors.New("oauth2/google: missing AccessKeyId credential")
  437. }
  438. if credentials.SecretAccessKey == "" {
  439. return result, errors.New("oauth2/google: missing SecretAccessKey credential")
  440. }
  441. return credentials, nil
  442. }
  443. func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string, headers map[string]string) (awsSecurityCredentials, error) {
  444. var result awsSecurityCredentials
  445. req, err := http.NewRequest("GET", fmt.Sprintf("%s/%s", cs.CredVerificationURL, roleName), nil)
  446. if err != nil {
  447. return result, err
  448. }
  449. req.Header.Add("Content-Type", "application/json")
  450. for name, value := range headers {
  451. req.Header.Add(name, value)
  452. }
  453. resp, err := cs.doRequest(req)
  454. if err != nil {
  455. return result, err
  456. }
  457. defer resp.Body.Close()
  458. respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
  459. if err != nil {
  460. return result, err
  461. }
  462. if resp.StatusCode != 200 {
  463. return result, fmt.Errorf("oauth2/google: unable to retrieve AWS security credentials - %s", string(respBody))
  464. }
  465. err = json.Unmarshal(respBody, &result)
  466. return result, err
  467. }
  468. func (cs *awsCredentialSource) getMetadataRoleName(headers map[string]string) (string, error) {
  469. if cs.CredVerificationURL == "" {
  470. return "", errors.New("oauth2/google: unable to determine the AWS metadata server security credentials endpoint")
  471. }
  472. req, err := http.NewRequest("GET", cs.CredVerificationURL, nil)
  473. if err != nil {
  474. return "", err
  475. }
  476. for name, value := range headers {
  477. req.Header.Add(name, value)
  478. }
  479. resp, err := cs.doRequest(req)
  480. if err != nil {
  481. return "", err
  482. }
  483. defer resp.Body.Close()
  484. respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
  485. if err != nil {
  486. return "", err
  487. }
  488. if resp.StatusCode != 200 {
  489. return "", fmt.Errorf("oauth2/google: unable to retrieve AWS role name - %s", string(respBody))
  490. }
  491. return string(respBody), nil
  492. }