aws.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470
  1. // Copyright 2021 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package externalaccount
  5. import (
  6. "bytes"
  7. "context"
  8. "crypto/hmac"
  9. "crypto/sha256"
  10. "encoding/hex"
  11. "encoding/json"
  12. "errors"
  13. "fmt"
  14. "io"
  15. "io/ioutil"
  16. "net/http"
  17. "net/url"
  18. "os"
  19. "path"
  20. "sort"
  21. "strings"
  22. "time"
  23. "golang.org/x/oauth2"
  24. )
// awsSecurityCredentials holds the AWS credentials used to sign requests.
// It is populated either from the AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY /
// AWS_SESSION_TOKEN environment variables or by decoding the JSON returned
// by the metadata server's security-credentials endpoint.
type awsSecurityCredentials struct {
	AccessKeyID     string `json:"AccessKeyID"`
	SecretAccessKey string `json:"SecretAccessKey"`
	// SecurityToken is optional; when non-empty, SignRequest sends it in the
	// x-amz-security-token header.
	SecurityToken string `json:"Token"`
}
// awsRequestSigner signs http requests using the AWS Signature Version 4
// scheme (see its SignRequest method).
type awsRequestSigner struct {
	// RegionName is the AWS region placed in the credential scope and mixed
	// into the derived signing key.
	RegionName string
	// AwsSecurityCredentials supplies the access key ID, the secret key and
	// the optional session token used for signing.
	AwsSecurityCredentials awsSecurityCredentials
}
  35. // getenv aliases os.Getenv for testing
  36. var getenv = os.Getenv
  37. const (
  38. // AWS Signature Version 4 signing algorithm identifier.
  39. awsAlgorithm = "AWS4-HMAC-SHA256"
  40. // The termination string for the AWS credential scope value as defined in
  41. // https://docs.aws.amazon.com/general/latest/gr/sigv4-create-string-to-sign.html
  42. awsRequestType = "aws4_request"
  43. // The AWS authorization header name for the security session token if available.
  44. awsSecurityTokenHeader = "x-amz-security-token"
  45. // The AWS authorization header name for the auto-generated date.
  46. awsDateHeader = "x-amz-date"
  47. awsTimeFormatLong = "20060102T150405Z"
  48. awsTimeFormatShort = "20060102"
  49. )
  50. func getSha256(input []byte) (string, error) {
  51. hash := sha256.New()
  52. if _, err := hash.Write(input); err != nil {
  53. return "", err
  54. }
  55. return hex.EncodeToString(hash.Sum(nil)), nil
  56. }
  57. func getHmacSha256(key, input []byte) ([]byte, error) {
  58. hash := hmac.New(sha256.New, key)
  59. if _, err := hash.Write(input); err != nil {
  60. return nil, err
  61. }
  62. return hash.Sum(nil), nil
  63. }
  64. func cloneRequest(r *http.Request) *http.Request {
  65. r2 := new(http.Request)
  66. *r2 = *r
  67. if r.Header != nil {
  68. r2.Header = make(http.Header, len(r.Header))
  69. // Find total number of values.
  70. headerCount := 0
  71. for _, headerValues := range r.Header {
  72. headerCount += len(headerValues)
  73. }
  74. copiedHeaders := make([]string, headerCount) // shared backing array for headers' values
  75. for headerKey, headerValues := range r.Header {
  76. headerCount = copy(copiedHeaders, headerValues)
  77. r2.Header[headerKey] = copiedHeaders[:headerCount:headerCount]
  78. copiedHeaders = copiedHeaders[headerCount:]
  79. }
  80. }
  81. return r2
  82. }
  83. func canonicalPath(req *http.Request) string {
  84. result := req.URL.EscapedPath()
  85. if result == "" {
  86. return "/"
  87. }
  88. return path.Clean(result)
  89. }
  90. func canonicalQuery(req *http.Request) string {
  91. queryValues := req.URL.Query()
  92. for queryKey := range queryValues {
  93. sort.Strings(queryValues[queryKey])
  94. }
  95. return queryValues.Encode()
  96. }
  97. func canonicalHeaders(req *http.Request) (string, string) {
  98. // Header keys need to be sorted alphabetically.
  99. var headers []string
  100. lowerCaseHeaders := make(http.Header)
  101. for k, v := range req.Header {
  102. k := strings.ToLower(k)
  103. if _, ok := lowerCaseHeaders[k]; ok {
  104. // include additional values
  105. lowerCaseHeaders[k] = append(lowerCaseHeaders[k], v...)
  106. } else {
  107. headers = append(headers, k)
  108. lowerCaseHeaders[k] = v
  109. }
  110. }
  111. sort.Strings(headers)
  112. var fullHeaders bytes.Buffer
  113. for _, header := range headers {
  114. headerValue := strings.Join(lowerCaseHeaders[header], ",")
  115. fullHeaders.WriteString(header)
  116. fullHeaders.WriteRune(':')
  117. fullHeaders.WriteString(headerValue)
  118. fullHeaders.WriteRune('\n')
  119. }
  120. return strings.Join(headers, ";"), fullHeaders.String()
  121. }
  122. func requestDataHash(req *http.Request) (string, error) {
  123. var requestData []byte
  124. if req.Body != nil {
  125. requestBody, err := req.GetBody()
  126. if err != nil {
  127. return "", err
  128. }
  129. defer requestBody.Close()
  130. requestData, err = ioutil.ReadAll(io.LimitReader(requestBody, 1<<20))
  131. if err != nil {
  132. return "", err
  133. }
  134. }
  135. return getSha256(requestData)
  136. }
  137. func requestHost(req *http.Request) string {
  138. if req.Host != "" {
  139. return req.Host
  140. }
  141. return req.URL.Host
  142. }
  143. func canonicalRequest(req *http.Request, canonicalHeaderColumns, canonicalHeaderData string) (string, error) {
  144. dataHash, err := requestDataHash(req)
  145. if err != nil {
  146. return "", err
  147. }
  148. return fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s", req.Method, canonicalPath(req), canonicalQuery(req), canonicalHeaderData, canonicalHeaderColumns, dataHash), nil
  149. }
  150. // SignRequest adds the appropriate headers to an http.Request
  151. // or returns an error if something prevented this.
  152. func (rs *awsRequestSigner) SignRequest(req *http.Request) error {
  153. signedRequest := cloneRequest(req)
  154. timestamp := now()
  155. signedRequest.Header.Add("host", requestHost(req))
  156. if rs.AwsSecurityCredentials.SecurityToken != "" {
  157. signedRequest.Header.Add(awsSecurityTokenHeader, rs.AwsSecurityCredentials.SecurityToken)
  158. }
  159. if signedRequest.Header.Get("date") == "" {
  160. signedRequest.Header.Add(awsDateHeader, timestamp.Format(awsTimeFormatLong))
  161. }
  162. authorizationCode, err := rs.generateAuthentication(signedRequest, timestamp)
  163. if err != nil {
  164. return err
  165. }
  166. signedRequest.Header.Set("Authorization", authorizationCode)
  167. req.Header = signedRequest.Header
  168. return nil
  169. }
  170. func (rs *awsRequestSigner) generateAuthentication(req *http.Request, timestamp time.Time) (string, error) {
  171. canonicalHeaderColumns, canonicalHeaderData := canonicalHeaders(req)
  172. dateStamp := timestamp.Format(awsTimeFormatShort)
  173. serviceName := ""
  174. if splitHost := strings.Split(requestHost(req), "."); len(splitHost) > 0 {
  175. serviceName = splitHost[0]
  176. }
  177. credentialScope := fmt.Sprintf("%s/%s/%s/%s", dateStamp, rs.RegionName, serviceName, awsRequestType)
  178. requestString, err := canonicalRequest(req, canonicalHeaderColumns, canonicalHeaderData)
  179. if err != nil {
  180. return "", err
  181. }
  182. requestHash, err := getSha256([]byte(requestString))
  183. if err != nil {
  184. return "", err
  185. }
  186. stringToSign := fmt.Sprintf("%s\n%s\n%s\n%s", awsAlgorithm, timestamp.Format(awsTimeFormatLong), credentialScope, requestHash)
  187. signingKey := []byte("AWS4" + rs.AwsSecurityCredentials.SecretAccessKey)
  188. for _, signingInput := range []string{
  189. dateStamp, rs.RegionName, serviceName, awsRequestType, stringToSign,
  190. } {
  191. signingKey, err = getHmacSha256(signingKey, []byte(signingInput))
  192. if err != nil {
  193. return "", err
  194. }
  195. }
  196. return fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s", awsAlgorithm, rs.AwsSecurityCredentials.AccessKeyID, credentialScope, canonicalHeaderColumns, hex.EncodeToString(signingKey)), nil
  197. }
// awsCredentialSource produces subject tokens for AWS workloads by signing an
// AWS STS GetCallerIdentity request (see its subjectToken method).
type awsCredentialSource struct {
	// EnvironmentID is not read in this part of the file.
	// NOTE(review): presumably identifies the AWS environment version
	// (e.g. "aws1") — confirm against the configuration parsing code.
	EnvironmentID string
	// RegionURL is fetched by getRegion when neither AWS_REGION nor
	// AWS_DEFAULT_REGION is set.
	RegionURL string
	// RegionalCredVerificationURL is the GetCallerIdentity endpoint that is
	// signed; its first "{region}" placeholder is replaced with the region.
	RegionalCredVerificationURL string
	// CredVerificationURL is the metadata endpoint that returns the role name
	// (getMetadataRoleName). "<CredVerificationURL>/<role>" returns that
	// role's security credentials (getMetadataSecurityCredentials).
	CredVerificationURL string
	// TargetResource, when non-empty, is sent (and signed) as the
	// x-goog-cloud-target-resource header.
	TargetResource string
	// requestSigner and region are resolved by subjectToken when
	// requestSigner is nil. subjectToken has a value receiver, so the
	// resolved values are not stored back on the caller's struct.
	requestSigner *awsRequestSigner
	region        string
	// ctx is attached to every request sent by doRequest and is used to
	// build the fallback client.
	ctx context.Context
	// client sends region and metadata requests; when nil, doRequest falls
	// back to oauth2.NewClient(ctx, nil).
	client *http.Client
}
// awsRequestHeader is a single header of the serialized signed request that
// is embedded in the subject token.
type awsRequestHeader struct {
	Key   string `json:"key"`
	Value string `json:"value"`
}

// awsRequest is the JSON form of the signed GetCallerIdentity request that
// subjectToken URL-encodes into the subject token. encoding/json emits
// fields in declaration order, so the field order here fixes the key order
// of the output.
type awsRequest struct {
	URL     string             `json:"url"`
	Method  string             `json:"method"`
	Headers []awsRequestHeader `json:"headers"`
}
  218. func (cs awsCredentialSource) doRequest(req *http.Request) (*http.Response, error) {
  219. if cs.client == nil {
  220. cs.client = oauth2.NewClient(cs.ctx, nil)
  221. }
  222. return cs.client.Do(req.WithContext(cs.ctx))
  223. }
  224. func (cs awsCredentialSource) subjectToken() (string, error) {
  225. if cs.requestSigner == nil {
  226. awsSecurityCredentials, err := cs.getSecurityCredentials()
  227. if err != nil {
  228. return "", err
  229. }
  230. if cs.region, err = cs.getRegion(); err != nil {
  231. return "", err
  232. }
  233. cs.requestSigner = &awsRequestSigner{
  234. RegionName: cs.region,
  235. AwsSecurityCredentials: awsSecurityCredentials,
  236. }
  237. }
  238. // Generate the signed request to AWS STS GetCallerIdentity API.
  239. // Use the required regional endpoint. Otherwise, the request will fail.
  240. req, err := http.NewRequest("POST", strings.Replace(cs.RegionalCredVerificationURL, "{region}", cs.region, 1), nil)
  241. if err != nil {
  242. return "", err
  243. }
  244. // The full, canonical resource name of the workload identity pool
  245. // provider, with or without the HTTPS prefix.
  246. // Including this header as part of the signature is recommended to
  247. // ensure data integrity.
  248. if cs.TargetResource != "" {
  249. req.Header.Add("x-goog-cloud-target-resource", cs.TargetResource)
  250. }
  251. cs.requestSigner.SignRequest(req)
  252. /*
  253. The GCP STS endpoint expects the headers to be formatted as:
  254. # [
  255. # {key: 'x-amz-date', value: '...'},
  256. # {key: 'Authorization', value: '...'},
  257. # ...
  258. # ]
  259. # And then serialized as:
  260. # quote(json.dumps({
  261. # url: '...',
  262. # method: 'POST',
  263. # headers: [{key: 'x-amz-date', value: '...'}, ...]
  264. # }))
  265. */
  266. awsSignedReq := awsRequest{
  267. URL: req.URL.String(),
  268. Method: "POST",
  269. }
  270. for headerKey, headerList := range req.Header {
  271. for _, headerValue := range headerList {
  272. awsSignedReq.Headers = append(awsSignedReq.Headers, awsRequestHeader{
  273. Key: headerKey,
  274. Value: headerValue,
  275. })
  276. }
  277. }
  278. sort.Slice(awsSignedReq.Headers, func(i, j int) bool {
  279. headerCompare := strings.Compare(awsSignedReq.Headers[i].Key, awsSignedReq.Headers[j].Key)
  280. if headerCompare == 0 {
  281. return strings.Compare(awsSignedReq.Headers[i].Value, awsSignedReq.Headers[j].Value) < 0
  282. }
  283. return headerCompare < 0
  284. })
  285. result, err := json.Marshal(awsSignedReq)
  286. if err != nil {
  287. return "", err
  288. }
  289. return url.QueryEscape(string(result)), nil
  290. }
  291. func (cs *awsCredentialSource) getRegion() (string, error) {
  292. if envAwsRegion := getenv("AWS_REGION"); envAwsRegion != "" {
  293. return envAwsRegion, nil
  294. }
  295. if envAwsRegion := getenv("AWS_DEFAULT_REGION"); envAwsRegion != "" {
  296. return envAwsRegion, nil
  297. }
  298. if cs.RegionURL == "" {
  299. return "", errors.New("oauth2/google: unable to determine AWS region")
  300. }
  301. req, err := http.NewRequest("GET", cs.RegionURL, nil)
  302. if err != nil {
  303. return "", err
  304. }
  305. resp, err := cs.doRequest(req)
  306. if err != nil {
  307. return "", err
  308. }
  309. defer resp.Body.Close()
  310. respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
  311. if err != nil {
  312. return "", err
  313. }
  314. if resp.StatusCode != 200 {
  315. return "", fmt.Errorf("oauth2/google: unable to retrieve AWS region - %s", string(respBody))
  316. }
  317. // This endpoint will return the region in format: us-east-2b.
  318. // Only the us-east-2 part should be used.
  319. respBodyEnd := 0
  320. if len(respBody) > 1 {
  321. respBodyEnd = len(respBody) - 1
  322. }
  323. return string(respBody[:respBodyEnd]), nil
  324. }
  325. func (cs *awsCredentialSource) getSecurityCredentials() (result awsSecurityCredentials, err error) {
  326. if accessKeyID := getenv("AWS_ACCESS_KEY_ID"); accessKeyID != "" {
  327. if secretAccessKey := getenv("AWS_SECRET_ACCESS_KEY"); secretAccessKey != "" {
  328. return awsSecurityCredentials{
  329. AccessKeyID: accessKeyID,
  330. SecretAccessKey: secretAccessKey,
  331. SecurityToken: getenv("AWS_SESSION_TOKEN"),
  332. }, nil
  333. }
  334. }
  335. roleName, err := cs.getMetadataRoleName()
  336. if err != nil {
  337. return
  338. }
  339. credentials, err := cs.getMetadataSecurityCredentials(roleName)
  340. if err != nil {
  341. return
  342. }
  343. if credentials.AccessKeyID == "" {
  344. return result, errors.New("oauth2/google: missing AccessKeyId credential")
  345. }
  346. if credentials.SecretAccessKey == "" {
  347. return result, errors.New("oauth2/google: missing SecretAccessKey credential")
  348. }
  349. return credentials, nil
  350. }
  351. func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string) (awsSecurityCredentials, error) {
  352. var result awsSecurityCredentials
  353. req, err := http.NewRequest("GET", fmt.Sprintf("%s/%s", cs.CredVerificationURL, roleName), nil)
  354. if err != nil {
  355. return result, err
  356. }
  357. req.Header.Add("Content-Type", "application/json")
  358. resp, err := cs.doRequest(req)
  359. if err != nil {
  360. return result, err
  361. }
  362. defer resp.Body.Close()
  363. respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
  364. if err != nil {
  365. return result, err
  366. }
  367. if resp.StatusCode != 200 {
  368. return result, fmt.Errorf("oauth2/google: unable to retrieve AWS security credentials - %s", string(respBody))
  369. }
  370. err = json.Unmarshal(respBody, &result)
  371. return result, err
  372. }
  373. func (cs *awsCredentialSource) getMetadataRoleName() (string, error) {
  374. if cs.CredVerificationURL == "" {
  375. return "", errors.New("oauth2/google: unable to determine the AWS metadata server security credentials endpoint")
  376. }
  377. req, err := http.NewRequest("GET", cs.CredVerificationURL, nil)
  378. if err != nil {
  379. return "", err
  380. }
  381. resp, err := cs.doRequest(req)
  382. if err != nil {
  383. return "", err
  384. }
  385. defer resp.Body.Close()
  386. respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
  387. if err != nil {
  388. return "", err
  389. }
  390. if resp.StatusCode != 200 {
  391. return "", fmt.Errorf("oauth2/google: unable to retrieve AWS role name - %s", string(respBody))
  392. }
  393. return string(respBody), nil
  394. }