hijack_test.go 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. package client
  import (
	"context"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/http/httptest"
	"net/url"
	"testing"
	"time"

	"github.com/docker/docker/api/server/httputils"
	"github.com/docker/docker/api/types"
	"github.com/pkg/errors"
	"gotest.tools/v3/assert"
)
  16. func TestTLSCloseWriter(t *testing.T) {
  17. t.Parallel()
  18. var chErr chan error
  19. ts := &httptest.Server{Config: &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
  20. chErr = make(chan error, 1)
  21. defer close(chErr)
  22. if err := httputils.ParseForm(req); err != nil {
  23. chErr <- errors.Wrap(err, "error parsing form")
  24. http.Error(w, err.Error(), http.StatusInternalServerError)
  25. return
  26. }
  27. r, rw, err := httputils.HijackConnection(w)
  28. if err != nil {
  29. chErr <- errors.Wrap(err, "error hijacking connection")
  30. http.Error(w, err.Error(), http.StatusInternalServerError)
  31. return
  32. }
  33. defer r.Close()
  34. fmt.Fprint(rw, "HTTP/1.1 101 UPGRADED\r\nContent-Type: application/vnd.docker.raw-stream\r\nConnection: Upgrade\r\nUpgrade: tcp\r\n\n")
  35. buf := make([]byte, 5)
  36. _, err = r.Read(buf)
  37. if err != nil {
  38. chErr <- errors.Wrap(err, "error reading from client")
  39. return
  40. }
  41. _, err = rw.Write(buf)
  42. if err != nil {
  43. chErr <- errors.Wrap(err, "error writing to client")
  44. return
  45. }
  46. })}}
  47. var (
  48. l net.Listener
  49. err error
  50. )
  51. for i := 1024; i < 10000; i++ {
  52. l, err = net.Listen("tcp4", fmt.Sprintf("127.0.0.1:%d", i))
  53. if err == nil {
  54. break
  55. }
  56. }
  57. assert.NilError(t, err)
  58. ts.Listener = l
  59. defer l.Close()
  60. defer func() {
  61. if chErr != nil {
  62. assert.Assert(t, <-chErr)
  63. }
  64. }()
  65. ts.StartTLS()
  66. defer ts.Close()
  67. serverURL, err := url.Parse(ts.URL)
  68. assert.NilError(t, err)
  69. client, err := NewClientWithOpts(WithHost("tcp://"+serverURL.Host), WithHTTPClient(ts.Client()))
  70. assert.NilError(t, err)
  71. resp, err := client.postHijacked(context.Background(), "/asdf", url.Values{}, nil, map[string][]string{"Content-Type": {"text/plain"}})
  72. assert.NilError(t, err)
  73. defer resp.Close()
  74. if _, ok := resp.Conn.(types.CloseWriter); !ok {
  75. t.Fatal("tls conn did not implement the CloseWrite interface")
  76. }
  77. _, err = resp.Conn.Write([]byte("hello"))
  78. assert.NilError(t, err)
  79. b, err := io.ReadAll(resp.Reader)
  80. assert.NilError(t, err)
  81. assert.Assert(t, string(b) == "hello")
  82. assert.Assert(t, resp.CloseWrite())
  83. // This should error since writes are closed
  84. _, err = resp.Conn.Write([]byte("no"))
  85. assert.Assert(t, err != nil)
  86. }