Merge pull request #46920 from dmcgowan/client-hijack-cleanup
Replace use of httputil in client hijack
This commit is contained in:
commit
0751141003
3 changed files with 54 additions and 90 deletions
|
@ -113,14 +113,6 @@ issues:
|
|||
path: "api/types/(volume|container)/"
|
||||
linters:
|
||||
- revive
|
||||
# FIXME temporarily suppress these. See #39926
|
||||
- text: "SA1019: httputil.NewClientConn"
|
||||
linters:
|
||||
- staticcheck
|
||||
# FIXME temporarily suppress these (related to the ones above)
|
||||
- text: "SA1019: httputil.ErrPersistEOF"
|
||||
linters:
|
||||
- staticcheck
|
||||
# FIXME temporarily suppress these (see https://github.com/gotestyourself/gotest.tools/issues/272)
|
||||
- text: "SA1019: (assert|cmp|is)\\.ErrorType is deprecated"
|
||||
linters:
|
||||
|
|
|
@ -6,17 +6,13 @@ import (
|
|||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/docker/docker/api/types"
|
||||
"github.com/docker/docker/api/types/versions"
|
||||
"github.com/pkg/errors"
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/codes"
|
||||
"go.opentelemetry.io/otel/propagation"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
|
||||
)
|
||||
|
||||
// postHijacked sends a POST request and hijacks the connection.
|
||||
|
@ -54,33 +50,16 @@ func (cli *Client) setupHijackConn(req *http.Request, proto string) (_ net.Conn,
|
|||
req.Header.Set("Connection", "Upgrade")
|
||||
req.Header.Set("Upgrade", proto)
|
||||
|
||||
// We aren't using the configured RoundTripper here so manually inject the trace context
|
||||
tp := cli.tp
|
||||
if tp == nil {
|
||||
if span := trace.SpanFromContext(ctx); span.SpanContext().IsValid() {
|
||||
tp = span.TracerProvider()
|
||||
} else {
|
||||
tp = otel.GetTracerProvider()
|
||||
}
|
||||
}
|
||||
|
||||
ctx, span := tp.Tracer("").Start(ctx, req.Method+" "+req.URL.Path, trace.WithSpanKind(trace.SpanKindClient))
|
||||
// FIXME(thaJeztah): httpconv.ClientRequest is now an internal package; replace this with alternative for semconv v1.21
|
||||
// span.SetAttributes(httpconv.ClientRequest(req)...)
|
||||
defer func() {
|
||||
if retErr != nil {
|
||||
span.RecordError(retErr)
|
||||
span.SetStatus(codes.Error, retErr.Error())
|
||||
}
|
||||
span.End()
|
||||
}()
|
||||
otel.GetTextMapPropagator().Inject(ctx, propagation.HeaderCarrier(req.Header))
|
||||
|
||||
dialer := cli.Dialer()
|
||||
conn, err := dialer(ctx)
|
||||
if err != nil {
|
||||
return nil, "", errors.Wrap(err, "cannot connect to the Docker daemon. Is 'docker daemon' running on this host?")
|
||||
}
|
||||
defer func() {
|
||||
if retErr != nil {
|
||||
conn.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
// When we set up a TCP connection for hijack, there could be long periods
|
||||
// of inactivity (a long running command with no output) that in certain
|
||||
|
@ -92,58 +71,29 @@ func (cli *Client) setupHijackConn(req *http.Request, proto string) (_ net.Conn,
|
|||
_ = tcpConn.SetKeepAlivePeriod(30 * time.Second)
|
||||
}
|
||||
|
||||
clientconn := httputil.NewClientConn(conn, nil)
|
||||
defer clientconn.Close()
|
||||
hc := &hijackedConn{conn, bufio.NewReader(conn)}
|
||||
|
||||
// Server hijacks the connection, error 'connection closed' expected
|
||||
resp, err := clientconn.Do(req)
|
||||
if resp != nil {
|
||||
// This is a simplified variant of "httpconv.ClientStatus(resp.StatusCode))";
|
||||
//
|
||||
// The main purpose of httpconv.ClientStatus() is to detect whether the
|
||||
// status was successful (1xx, 2xx, 3xx) or non-successful (4xx/5xx).
|
||||
//
|
||||
// It also provides complex logic to *validate* status-codes against
|
||||
// a hard-coded list meant to exclude "bogus" status codes in "success"
|
||||
// ranges (1xx, 2xx) and convert them into an error status. That code
|
||||
// seemed over-reaching (and not accounting for potential future valid
|
||||
// status codes). We assume we only get valid status codes, and only
|
||||
// look at status-code ranges.
|
||||
//
|
||||
// For reference, see:
|
||||
// https://github.com/open-telemetry/opentelemetry-go/blob/v1.21.0/semconv/v1.17.0/httpconv/http.go#L85-L89
|
||||
// https://github.com/open-telemetry/opentelemetry-go/blob/v1.21.0/semconv/internal/v2/http.go#L322-L330
|
||||
// https://github.com/open-telemetry/opentelemetry-go/blob/v1.21.0/semconv/internal/v2/http.go#L356-L404
|
||||
code := codes.Unset
|
||||
if resp.StatusCode >= http.StatusBadRequest {
|
||||
code = codes.Error
|
||||
}
|
||||
span.SetStatus(code, "")
|
||||
resp, err := otelhttp.NewTransport(hc).RoundTrip(req)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if resp.StatusCode != http.StatusSwitchingProtocols {
|
||||
_ = resp.Body.Close()
|
||||
return nil, "", fmt.Errorf("unable to upgrade to %s, received %d", proto, resp.StatusCode)
|
||||
}
|
||||
|
||||
//nolint:staticcheck // ignore SA1019 for connecting to old (pre go1.8) daemons
|
||||
if err != httputil.ErrPersistEOF {
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if resp.StatusCode != http.StatusSwitchingProtocols {
|
||||
_ = resp.Body.Close()
|
||||
return nil, "", fmt.Errorf("unable to upgrade to %s, received %d", proto, resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
c, br := clientconn.Hijack()
|
||||
if br.Buffered() > 0 {
|
||||
if hc.r.Buffered() > 0 {
|
||||
// If there is buffered content, wrap the connection. We return an
|
||||
// object that implements CloseWrite if the underlying connection
|
||||
// implements it.
|
||||
if _, ok := c.(types.CloseWriter); ok {
|
||||
c = &hijackedConnCloseWriter{&hijackedConn{c, br}}
|
||||
if _, ok := hc.Conn.(types.CloseWriter); ok {
|
||||
conn = &hijackedConnCloseWriter{hc}
|
||||
} else {
|
||||
c = &hijackedConn{c, br}
|
||||
conn = hc
|
||||
}
|
||||
} else {
|
||||
br.Reset(nil)
|
||||
hc.r.Reset(nil)
|
||||
}
|
||||
|
||||
var mediaType string
|
||||
|
@ -152,7 +102,7 @@ func (cli *Client) setupHijackConn(req *http.Request, proto string) (_ net.Conn,
|
|||
mediaType = resp.Header.Get("Content-Type")
|
||||
}
|
||||
|
||||
return c, mediaType, nil
|
||||
return conn, mediaType, nil
|
||||
}
|
||||
|
||||
// hijackedConn wraps a net.Conn and is returned by setupHijackConn in the case
|
||||
|
@ -164,6 +114,13 @@ type hijackedConn struct {
|
|||
r *bufio.Reader
|
||||
}
|
||||
|
||||
func (c *hijackedConn) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
if err := req.Write(c.Conn); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return http.ReadResponse(c.r, req)
|
||||
}
|
||||
|
||||
func (c *hijackedConn) Read(b []byte) (int, error) {
|
||||
return c.r.Read(b)
|
||||
}
|
||||
|
|
|
@ -8,7 +8,6 @@ import (
|
|||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
@ -25,6 +24,7 @@ import (
|
|||
"github.com/docker/docker/pkg/archive"
|
||||
"github.com/docker/docker/pkg/authorization"
|
||||
"github.com/docker/docker/testutil/environment"
|
||||
"github.com/docker/go-connections/sockets"
|
||||
"gotest.tools/v3/assert"
|
||||
"gotest.tools/v3/skip"
|
||||
)
|
||||
|
@ -81,6 +81,17 @@ func isAllowed(reqURI string) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
func socketHTTPClient(u *url.URL) (*http.Client, error) {
|
||||
transport := &http.Transport{}
|
||||
err := sockets.ConfigureTransport(transport, u.Scheme, u.Path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &http.Client{
|
||||
Transport: transport,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func TestAuthZPluginAllowRequest(t *testing.T) {
|
||||
ctx := setupTestV1(t)
|
||||
|
||||
|
@ -176,15 +187,17 @@ func TestAuthZPluginAPIDenyResponse(t *testing.T) {
|
|||
daemonURL, err := url.Parse(d.Sock())
|
||||
assert.NilError(t, err)
|
||||
|
||||
conn, err := net.DialTimeout(daemonURL.Scheme, daemonURL.Path, time.Second*10)
|
||||
socketClient, err := socketHTTPClient(daemonURL)
|
||||
assert.NilError(t, err)
|
||||
c := httputil.NewClientConn(conn, nil)
|
||||
req, err := http.NewRequest(http.MethodGet, "/version", nil)
|
||||
assert.NilError(t, err)
|
||||
req = req.WithContext(ctx)
|
||||
resp, err := c.Do(req)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "/version", nil)
|
||||
assert.NilError(t, err)
|
||||
req.URL.Scheme = "http"
|
||||
req.URL.Host = client.DummyHost
|
||||
|
||||
resp, err := socketClient.Do(req)
|
||||
assert.NilError(t, err)
|
||||
|
||||
assert.DeepEqual(t, http.StatusForbidden, resp.StatusCode)
|
||||
}
|
||||
|
||||
|
@ -471,13 +484,15 @@ func TestAuthZPluginHeader(t *testing.T) {
|
|||
daemonURL, err := url.Parse(d.Sock())
|
||||
assert.NilError(t, err)
|
||||
|
||||
conn, err := net.DialTimeout(daemonURL.Scheme, daemonURL.Path, time.Second*10)
|
||||
socketClient, err := socketHTTPClient(daemonURL)
|
||||
assert.NilError(t, err)
|
||||
client := httputil.NewClientConn(conn, nil)
|
||||
req, err := http.NewRequest(http.MethodGet, "/version", nil)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "/version", nil)
|
||||
assert.NilError(t, err)
|
||||
req = req.WithContext(ctx)
|
||||
resp, err := client.Do(req)
|
||||
req.URL.Scheme = "http"
|
||||
req.URL.Host = client.DummyHost
|
||||
|
||||
resp, err := socketClient.Do(req)
|
||||
assert.NilError(t, err)
|
||||
assert.Equal(t, "application/json", resp.Header["Content-Type"][0])
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue