middleware.go 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. package eventstreamapi
  2. import (
  3. "context"
  4. "fmt"
  5. "github.com/aws/smithy-go/middleware"
  6. smithyhttp "github.com/aws/smithy-go/transport/http"
  7. "io"
  8. )
  9. type eventStreamWriterKey struct{}
  10. // GetInputStreamWriter returns EventTypeHeader io.PipeWriter used for the operation's input event stream.
  11. func GetInputStreamWriter(ctx context.Context) io.WriteCloser {
  12. writeCloser, _ := middleware.GetStackValue(ctx, eventStreamWriterKey{}).(io.WriteCloser)
  13. return writeCloser
  14. }
  15. func setInputStreamWriter(ctx context.Context, writeCloser io.WriteCloser) context.Context {
  16. return middleware.WithStackValue(ctx, eventStreamWriterKey{}, writeCloser)
  17. }
  18. // InitializeStreamWriter is a Finalize middleware initializes an in-memory pipe for sending event stream messages
  19. // via the HTTP request body.
  20. type InitializeStreamWriter struct{}
  21. // AddInitializeStreamWriter adds the InitializeStreamWriter middleware to the provided stack.
  22. func AddInitializeStreamWriter(stack *middleware.Stack) error {
  23. return stack.Finalize.Add(&InitializeStreamWriter{}, middleware.After)
  24. }
  25. // ID returns the identifier for the middleware.
  26. func (i *InitializeStreamWriter) ID() string {
  27. return "InitializeStreamWriter"
  28. }
  29. // HandleFinalize is the middleware implementation.
  30. func (i *InitializeStreamWriter) HandleFinalize(
  31. ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler,
  32. ) (
  33. out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
  34. ) {
  35. request, ok := in.Request.(*smithyhttp.Request)
  36. if !ok {
  37. return out, metadata, fmt.Errorf("unknown transport type: %T", in.Request)
  38. }
  39. inputReader, inputWriter := io.Pipe()
  40. defer func() {
  41. if err == nil {
  42. return
  43. }
  44. _ = inputReader.Close()
  45. _ = inputWriter.Close()
  46. }()
  47. request, err = request.SetStream(inputReader)
  48. if err != nil {
  49. return out, metadata, err
  50. }
  51. in.Request = request
  52. ctx = setInputStreamWriter(ctx, inputWriter)
  53. out, metadata, err = next.HandleFinalize(ctx, in)
  54. if err != nil {
  55. return out, metadata, err
  56. }
  57. return out, metadata, err
  58. }