request_middleware.go 7.4 KB


  1. package imds
  2. import (
  3. "bytes"
  4. "context"
  5. "fmt"
  6. "io/ioutil"
  7. "net/url"
  8. "path"
  9. "time"
  10. awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
  11. "github.com/aws/aws-sdk-go-v2/aws/retry"
  12. "github.com/aws/smithy-go/middleware"
  13. smithyhttp "github.com/aws/smithy-go/transport/http"
  14. )
  15. func addAPIRequestMiddleware(stack *middleware.Stack,
  16. options Options,
  17. getPath func(interface{}) (string, error),
  18. getOutput func(*smithyhttp.Response) (interface{}, error),
  19. ) (err error) {
  20. err = addRequestMiddleware(stack, options, "GET", getPath, getOutput)
  21. if err != nil {
  22. return err
  23. }
  24. // Token Serializer build and state management.
  25. if !options.disableAPIToken {
  26. err = stack.Finalize.Insert(options.tokenProvider, (*retry.Attempt)(nil).ID(), middleware.After)
  27. if err != nil {
  28. return err
  29. }
  30. err = stack.Deserialize.Insert(options.tokenProvider, "OperationDeserializer", middleware.Before)
  31. if err != nil {
  32. return err
  33. }
  34. }
  35. return nil
  36. }
  37. func addRequestMiddleware(stack *middleware.Stack,
  38. options Options,
  39. method string,
  40. getPath func(interface{}) (string, error),
  41. getOutput func(*smithyhttp.Response) (interface{}, error),
  42. ) (err error) {
  43. err = awsmiddleware.AddSDKAgentKey(awsmiddleware.FeatureMetadata, "ec2-imds")(stack)
  44. if err != nil {
  45. return err
  46. }
  47. // Operation timeout
  48. err = stack.Initialize.Add(&operationTimeout{
  49. DefaultTimeout: defaultOperationTimeout,
  50. }, middleware.Before)
  51. if err != nil {
  52. return err
  53. }
  54. // Operation Serializer
  55. err = stack.Serialize.Add(&serializeRequest{
  56. GetPath: getPath,
  57. Method: method,
  58. }, middleware.After)
  59. if err != nil {
  60. return err
  61. }
  62. // Operation endpoint resolver
  63. err = stack.Serialize.Insert(&resolveEndpoint{
  64. Endpoint: options.Endpoint,
  65. EndpointMode: options.EndpointMode,
  66. }, "OperationSerializer", middleware.Before)
  67. if err != nil {
  68. return err
  69. }
  70. // Operation Deserializer
  71. err = stack.Deserialize.Add(&deserializeResponse{
  72. GetOutput: getOutput,
  73. }, middleware.After)
  74. if err != nil {
  75. return err
  76. }
  77. err = stack.Deserialize.Add(&smithyhttp.RequestResponseLogger{
  78. LogRequest: options.ClientLogMode.IsRequest(),
  79. LogRequestWithBody: options.ClientLogMode.IsRequestWithBody(),
  80. LogResponse: options.ClientLogMode.IsResponse(),
  81. LogResponseWithBody: options.ClientLogMode.IsResponseWithBody(),
  82. }, middleware.After)
  83. if err != nil {
  84. return err
  85. }
  86. err = addSetLoggerMiddleware(stack, options)
  87. if err != nil {
  88. return err
  89. }
  90. // Retry support
  91. return retry.AddRetryMiddlewares(stack, retry.AddRetryMiddlewaresOptions{
  92. Retryer: options.Retryer,
  93. LogRetryAttempts: options.ClientLogMode.IsRetries(),
  94. })
  95. }
  96. func addSetLoggerMiddleware(stack *middleware.Stack, o Options) error {
  97. return middleware.AddSetLoggerMiddleware(stack, o.Logger)
  98. }
  // serializeRequest is the Serialize-step middleware that fills in the HTTP
  // method and URL path of an outgoing request.
  type serializeRequest struct {
  	// GetPath derives the request URL path from the operation's input
  	// parameters.
  	GetPath func(interface{}) (string, error)

  	// Method is the HTTP method assigned to the request.
  	Method string
  }

  // ID returns the identifier under which this middleware is registered in the
  // stack.
  func (m *serializeRequest) ID() string {
  	return "OperationSerializer"
  }
  106. func (m *serializeRequest) HandleSerialize(
  107. ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler,
  108. ) (
  109. out middleware.SerializeOutput, metadata middleware.Metadata, err error,
  110. ) {
  111. request, ok := in.Request.(*smithyhttp.Request)
  112. if !ok {
  113. return out, metadata, fmt.Errorf("unknown transport type %T", in.Request)
  114. }
  115. reqPath, err := m.GetPath(in.Parameters)
  116. if err != nil {
  117. return out, metadata, fmt.Errorf("unable to get request URL path, %w", err)
  118. }
  119. request.Request.URL.Path = reqPath
  120. request.Request.Method = m.Method
  121. return next.HandleSerialize(ctx, in)
  122. }
  123. type deserializeResponse struct {
  124. GetOutput func(*smithyhttp.Response) (interface{}, error)
  125. }
  126. func (*deserializeResponse) ID() string {
  127. return "OperationDeserializer"
  128. }
  129. func (m *deserializeResponse) HandleDeserialize(
  130. ctx context.Context, in middleware.DeserializeInput, next middleware.DeserializeHandler,
  131. ) (
  132. out middleware.DeserializeOutput, metadata middleware.Metadata, err error,
  133. ) {
  134. out, metadata, err = next.HandleDeserialize(ctx, in)
  135. if err != nil {
  136. return out, metadata, err
  137. }
  138. resp, ok := out.RawResponse.(*smithyhttp.Response)
  139. if !ok {
  140. return out, metadata, fmt.Errorf(
  141. "unexpected transport response type, %T, want %T", out.RawResponse, resp)
  142. }
  143. defer resp.Body.Close()
  144. // read the full body so that any operation timeouts cleanup will not race
  145. // the body being read.
  146. body, err := ioutil.ReadAll(resp.Body)
  147. if err != nil {
  148. return out, metadata, fmt.Errorf("read response body failed, %w", err)
  149. }
  150. resp.Body = ioutil.NopCloser(bytes.NewReader(body))
  151. // Anything that's not 200 |< 300 is error
  152. if resp.StatusCode < 200 || resp.StatusCode >= 300 {
  153. return out, metadata, &smithyhttp.ResponseError{
  154. Response: resp,
  155. Err: fmt.Errorf("request to EC2 IMDS failed"),
  156. }
  157. }
  158. result, err := m.GetOutput(resp)
  159. if err != nil {
  160. return out, metadata, fmt.Errorf(
  161. "unable to get deserialized result for response, %w", err,
  162. )
  163. }
  164. out.Result = result
  165. return out, metadata, err
  166. }
  167. type resolveEndpoint struct {
  168. Endpoint string
  169. EndpointMode EndpointModeState
  170. }
  171. func (*resolveEndpoint) ID() string {
  172. return "ResolveEndpoint"
  173. }
  174. func (m *resolveEndpoint) HandleSerialize(
  175. ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler,
  176. ) (
  177. out middleware.SerializeOutput, metadata middleware.Metadata, err error,
  178. ) {
  179. req, ok := in.Request.(*smithyhttp.Request)
  180. if !ok {
  181. return out, metadata, fmt.Errorf("unknown transport type %T", in.Request)
  182. }
  183. var endpoint string
  184. if len(m.Endpoint) > 0 {
  185. endpoint = m.Endpoint
  186. } else {
  187. switch m.EndpointMode {
  188. case EndpointModeStateIPv6:
  189. endpoint = defaultIPv6Endpoint
  190. case EndpointModeStateIPv4:
  191. fallthrough
  192. case EndpointModeStateUnset:
  193. endpoint = defaultIPv4Endpoint
  194. default:
  195. return out, metadata, fmt.Errorf("unsupported IMDS endpoint mode")
  196. }
  197. }
  198. req.URL, err = url.Parse(endpoint)
  199. if err != nil {
  200. return out, metadata, fmt.Errorf("failed to parse endpoint URL: %w", err)
  201. }
  202. return next.HandleSerialize(ctx, in)
  203. }
  // defaultOperationTimeout is the timeout given to the operationTimeout
  // middleware when the request middleware is assembled.
  const defaultOperationTimeout = 5 * time.Second
  // operationTimeout is an Initialize-step middleware that applies a default
  // timeout to the stack when the Context it was invoked with carries no
  // deadline. Downstream middleware must finish before that timeout elapses or
  // the context is canceled.
  //
  // A zero DefaultTimeout disables the behavior: a Context without a deadline
  // is passed along unchanged.
  //
  // Downstream middleware must also fully consume any resource that is canceled
  // together with the stack's context before returning; otherwise the timeout
  // cleanup races with the resource being consumed upstream.
  type operationTimeout struct {
  	// DefaultTimeout is the timeout applied when the Context has no deadline.
  	DefaultTimeout time.Duration
  }

  // ID returns the identifier under which this middleware is registered in the
  // stack.
  func (m *operationTimeout) ID() string {
  	return "OperationTimeout"
  }
  222. func (m *operationTimeout) HandleInitialize(
  223. ctx context.Context, input middleware.InitializeInput, next middleware.InitializeHandler,
  224. ) (
  225. output middleware.InitializeOutput, metadata middleware.Metadata, err error,
  226. ) {
  227. if _, ok := ctx.Deadline(); !ok && m.DefaultTimeout != 0 {
  228. var cancelFn func()
  229. ctx, cancelFn = context.WithTimeout(ctx, m.DefaultTimeout)
  230. defer cancelFn()
  231. }
  232. return next.HandleInitialize(ctx, input)
  233. }
  // appendURIPath joins a URI path component to the existing path with `/`
  // separators between the path components. If the path being added ends with a
  // trailing `/` that slash will be maintained.
  //
  // path.Join cleans its result, which strips a trailing slash unless the
  // joined path is the root "/". The trailing slash is therefore re-added only
  // when the joined path does not already end with one, so that joining "/"
  // onto an empty or root base yields "/" rather than "//".
  func appendURIPath(base, add string) string {
  	reqPath := path.Join(base, add)

  	// Only a non-empty add that ends in '/' asks for a trailing slash.
  	if len(add) == 0 || add[len(add)-1] != '/' {
  		return reqPath
  	}

  	// The joined path is the root and already ends in '/'; do not double it.
  	if n := len(reqPath); n != 0 && reqPath[n-1] == '/' {
  		return reqPath
  	}

  	return reqPath + "/"
  }