request_middleware.go 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307
  1. package imds
  2. import (
  3. "bytes"
  4. "context"
  5. "fmt"
  6. "io/ioutil"
  7. "net/url"
  8. "path"
  9. "time"
  10. awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
  11. "github.com/aws/aws-sdk-go-v2/aws/retry"
  12. "github.com/aws/smithy-go/middleware"
  13. smithyhttp "github.com/aws/smithy-go/transport/http"
  14. )
  15. func addAPIRequestMiddleware(stack *middleware.Stack,
  16. options Options,
  17. operation string,
  18. getPath func(interface{}) (string, error),
  19. getOutput func(*smithyhttp.Response) (interface{}, error),
  20. ) (err error) {
  21. err = addRequestMiddleware(stack, options, "GET", operation, getPath, getOutput)
  22. if err != nil {
  23. return err
  24. }
  25. // Token Serializer build and state management.
  26. if !options.disableAPIToken {
  27. err = stack.Finalize.Insert(options.tokenProvider, (*retry.Attempt)(nil).ID(), middleware.After)
  28. if err != nil {
  29. return err
  30. }
  31. err = stack.Deserialize.Insert(options.tokenProvider, "OperationDeserializer", middleware.Before)
  32. if err != nil {
  33. return err
  34. }
  35. }
  36. return nil
  37. }
  38. func addRequestMiddleware(stack *middleware.Stack,
  39. options Options,
  40. method string,
  41. operation string,
  42. getPath func(interface{}) (string, error),
  43. getOutput func(*smithyhttp.Response) (interface{}, error),
  44. ) (err error) {
  45. err = awsmiddleware.AddSDKAgentKey(awsmiddleware.FeatureMetadata, "ec2-imds")(stack)
  46. if err != nil {
  47. return err
  48. }
  49. // Operation timeout
  50. err = stack.Initialize.Add(&operationTimeout{
  51. DefaultTimeout: defaultOperationTimeout,
  52. }, middleware.Before)
  53. if err != nil {
  54. return err
  55. }
  56. // Operation Serializer
  57. err = stack.Serialize.Add(&serializeRequest{
  58. GetPath: getPath,
  59. Method: method,
  60. }, middleware.After)
  61. if err != nil {
  62. return err
  63. }
  64. // Operation endpoint resolver
  65. err = stack.Serialize.Insert(&resolveEndpoint{
  66. Endpoint: options.Endpoint,
  67. EndpointMode: options.EndpointMode,
  68. }, "OperationSerializer", middleware.Before)
  69. if err != nil {
  70. return err
  71. }
  72. // Operation Deserializer
  73. err = stack.Deserialize.Add(&deserializeResponse{
  74. GetOutput: getOutput,
  75. }, middleware.After)
  76. if err != nil {
  77. return err
  78. }
  79. err = stack.Deserialize.Add(&smithyhttp.RequestResponseLogger{
  80. LogRequest: options.ClientLogMode.IsRequest(),
  81. LogRequestWithBody: options.ClientLogMode.IsRequestWithBody(),
  82. LogResponse: options.ClientLogMode.IsResponse(),
  83. LogResponseWithBody: options.ClientLogMode.IsResponseWithBody(),
  84. }, middleware.After)
  85. if err != nil {
  86. return err
  87. }
  88. err = addSetLoggerMiddleware(stack, options)
  89. if err != nil {
  90. return err
  91. }
  92. if err := addProtocolFinalizerMiddlewares(stack, options, operation); err != nil {
  93. return fmt.Errorf("add protocol finalizers: %w", err)
  94. }
  95. // Retry support
  96. return retry.AddRetryMiddlewares(stack, retry.AddRetryMiddlewaresOptions{
  97. Retryer: options.Retryer,
  98. LogRetryAttempts: options.ClientLogMode.IsRetries(),
  99. })
  100. }
  101. func addSetLoggerMiddleware(stack *middleware.Stack, o Options) error {
  102. return middleware.AddSetLoggerMiddleware(stack, o.Logger)
  103. }
  // serializeRequest is the Serialize step middleware that fills in the URL
  // path and HTTP method of an outgoing request.
  type serializeRequest struct {
  	// GetPath maps the operation's input parameters to the request URL path.
  	GetPath func(interface{}) (string, error)

  	// Method is the HTTP method assigned to the request.
  	Method string
  }

  // ID returns the identifier this middleware is registered under.
  func (*serializeRequest) ID() string {
  	return "OperationSerializer"
  }
  111. func (m *serializeRequest) HandleSerialize(
  112. ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler,
  113. ) (
  114. out middleware.SerializeOutput, metadata middleware.Metadata, err error,
  115. ) {
  116. request, ok := in.Request.(*smithyhttp.Request)
  117. if !ok {
  118. return out, metadata, fmt.Errorf("unknown transport type %T", in.Request)
  119. }
  120. reqPath, err := m.GetPath(in.Parameters)
  121. if err != nil {
  122. return out, metadata, fmt.Errorf("unable to get request URL path, %w", err)
  123. }
  124. request.Request.URL.Path = reqPath
  125. request.Request.Method = m.Method
  126. return next.HandleSerialize(ctx, in)
  127. }
  128. type deserializeResponse struct {
  129. GetOutput func(*smithyhttp.Response) (interface{}, error)
  130. }
  131. func (*deserializeResponse) ID() string {
  132. return "OperationDeserializer"
  133. }
  134. func (m *deserializeResponse) HandleDeserialize(
  135. ctx context.Context, in middleware.DeserializeInput, next middleware.DeserializeHandler,
  136. ) (
  137. out middleware.DeserializeOutput, metadata middleware.Metadata, err error,
  138. ) {
  139. out, metadata, err = next.HandleDeserialize(ctx, in)
  140. if err != nil {
  141. return out, metadata, err
  142. }
  143. resp, ok := out.RawResponse.(*smithyhttp.Response)
  144. if !ok {
  145. return out, metadata, fmt.Errorf(
  146. "unexpected transport response type, %T, want %T", out.RawResponse, resp)
  147. }
  148. defer resp.Body.Close()
  149. // read the full body so that any operation timeouts cleanup will not race
  150. // the body being read.
  151. body, err := ioutil.ReadAll(resp.Body)
  152. if err != nil {
  153. return out, metadata, fmt.Errorf("read response body failed, %w", err)
  154. }
  155. resp.Body = ioutil.NopCloser(bytes.NewReader(body))
  156. // Anything that's not 200 |< 300 is error
  157. if resp.StatusCode < 200 || resp.StatusCode >= 300 {
  158. return out, metadata, &smithyhttp.ResponseError{
  159. Response: resp,
  160. Err: fmt.Errorf("request to EC2 IMDS failed"),
  161. }
  162. }
  163. result, err := m.GetOutput(resp)
  164. if err != nil {
  165. return out, metadata, fmt.Errorf(
  166. "unable to get deserialized result for response, %w", err,
  167. )
  168. }
  169. out.Result = result
  170. return out, metadata, err
  171. }
  172. type resolveEndpoint struct {
  173. Endpoint string
  174. EndpointMode EndpointModeState
  175. }
  176. func (*resolveEndpoint) ID() string {
  177. return "ResolveEndpoint"
  178. }
  179. func (m *resolveEndpoint) HandleSerialize(
  180. ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler,
  181. ) (
  182. out middleware.SerializeOutput, metadata middleware.Metadata, err error,
  183. ) {
  184. req, ok := in.Request.(*smithyhttp.Request)
  185. if !ok {
  186. return out, metadata, fmt.Errorf("unknown transport type %T", in.Request)
  187. }
  188. var endpoint string
  189. if len(m.Endpoint) > 0 {
  190. endpoint = m.Endpoint
  191. } else {
  192. switch m.EndpointMode {
  193. case EndpointModeStateIPv6:
  194. endpoint = defaultIPv6Endpoint
  195. case EndpointModeStateIPv4:
  196. fallthrough
  197. case EndpointModeStateUnset:
  198. endpoint = defaultIPv4Endpoint
  199. default:
  200. return out, metadata, fmt.Errorf("unsupported IMDS endpoint mode")
  201. }
  202. }
  203. req.URL, err = url.Parse(endpoint)
  204. if err != nil {
  205. return out, metadata, fmt.Errorf("failed to parse endpoint URL: %w", err)
  206. }
  207. return next.HandleSerialize(ctx, in)
  208. }
  // defaultOperationTimeout is the timeout given to the operationTimeout
  // middleware when it is registered on a request's stack.
  const defaultOperationTimeout = 5 * time.Second
  // operationTimeout applies a timeout to the middleware stack when the Context
  // the stack was invoked with carries no deadline. The next middleware has to
  // finish before the timeout elapses, otherwise the context is canceled.
  //
  // A zero DefaultTimeout disables the behavior: a Context without a deadline
  // is then passed along unchanged.
  //
  // The next middleware must also make sure that every resource canceled
  // together with the stack's context has been fully consumed before it
  // returns. Otherwise the timeout cleanup races the upstream consumer of that
  // resource.
  type operationTimeout struct {
  	// DefaultTimeout is the timeout applied when the Context has no deadline.
  	DefaultTimeout time.Duration
  }

  // ID returns the identifier this middleware is registered under.
  func (*operationTimeout) ID() string {
  	return "OperationTimeout"
  }
  227. func (m *operationTimeout) HandleInitialize(
  228. ctx context.Context, input middleware.InitializeInput, next middleware.InitializeHandler,
  229. ) (
  230. output middleware.InitializeOutput, metadata middleware.Metadata, err error,
  231. ) {
  232. if _, ok := ctx.Deadline(); !ok && m.DefaultTimeout != 0 {
  233. var cancelFn func()
  234. ctx, cancelFn = context.WithTimeout(ctx, m.DefaultTimeout)
  235. defer cancelFn()
  236. }
  237. return next.HandleInitialize(ctx, input)
  238. }
  // appendURIPath joins a URI path component to the existing path with `/`
  // separators between the path components. If the path being added ends with a
  // trailing `/` that slash will be maintained.
  //
  // The joined path is cleaned by path.Join, which drops any trailing slash
  // except on the root path "/". The trailing slash is therefore re-added only
  // when the joined path does not already end with one, so that joining onto
  // an empty or root base with add == "/" yields "/" rather than "//".
  func appendURIPath(base, add string) string {
  	reqPath := path.Join(base, add)

  	wantsTrailingSlash := len(add) != 0 && add[len(add)-1] == '/'
  	// A non-empty add guarantees a non-empty joined path, so indexing the
  	// last byte of reqPath is safe whenever wantsTrailingSlash is true.
  	if wantsTrailingSlash && reqPath[len(reqPath)-1] != '/' {
  		reqPath += "/"
  	}

  	return reqPath
  }
  249. func addProtocolFinalizerMiddlewares(stack *middleware.Stack, options Options, operation string) error {
  250. if err := stack.Finalize.Add(&resolveAuthSchemeMiddleware{operation: operation, options: options}, middleware.Before); err != nil {
  251. return fmt.Errorf("add ResolveAuthScheme: %w", err)
  252. }
  253. if err := stack.Finalize.Insert(&getIdentityMiddleware{options: options}, "ResolveAuthScheme", middleware.After); err != nil {
  254. return fmt.Errorf("add GetIdentity: %w", err)
  255. }
  256. if err := stack.Finalize.Insert(&resolveEndpointV2Middleware{options: options}, "GetIdentity", middleware.After); err != nil {
  257. return fmt.Errorf("add ResolveEndpointV2: %w", err)
  258. }
  259. if err := stack.Finalize.Insert(&signRequestMiddleware{}, "ResolveEndpointV2", middleware.After); err != nil {
  260. return fmt.Errorf("add Signing: %w", err)
  261. }
  262. return nil
  263. }