Bläddra i källkod

Merge pull request #46920 from dmcgowan/client-hijack-cleanup

Replace use of httputil in client hijack
Sebastiaan van Stijn 1 år sedan
förälder
incheckning
0751141003
3 ändrade filer med 52 tillägg och 88 borttagningar
  1. 0 8
      .golangci.yml
  2. 26 69
      client/hijack.go
  3. 26 11
      integration/plugin/authz/authz_plugin_test.go

+ 0 - 8
.golangci.yml

@@ -113,14 +113,6 @@ issues:
       path: "api/types/(volume|container)/"
       path: "api/types/(volume|container)/"
       linters:
       linters:
         - revive
         - revive
-    # FIXME temporarily suppress these. See #39926
-    - text: "SA1019: httputil.NewClientConn"
-      linters:
-        - staticcheck
-    # FIXME temporarily suppress these (related to the ones above)
-    - text: "SA1019: httputil.ErrPersistEOF"
-      linters:
-        - staticcheck
     # FIXME temporarily suppress these (see https://github.com/gotestyourself/gotest.tools/issues/272)
     # FIXME temporarily suppress these (see https://github.com/gotestyourself/gotest.tools/issues/272)
     - text: "SA1019: (assert|cmp|is)\\.ErrorType is deprecated"
     - text: "SA1019: (assert|cmp|is)\\.ErrorType is deprecated"
       linters:
       linters:

+ 26 - 69
client/hijack.go

@@ -6,17 +6,13 @@ import (
 	"fmt"
 	"fmt"
 	"net"
 	"net"
 	"net/http"
 	"net/http"
-	"net/http/httputil"
 	"net/url"
 	"net/url"
 	"time"
 	"time"
 
 
 	"github.com/docker/docker/api/types"
 	"github.com/docker/docker/api/types"
 	"github.com/docker/docker/api/types/versions"
 	"github.com/docker/docker/api/types/versions"
 	"github.com/pkg/errors"
 	"github.com/pkg/errors"
-	"go.opentelemetry.io/otel"
-	"go.opentelemetry.io/otel/codes"
-	"go.opentelemetry.io/otel/propagation"
-	"go.opentelemetry.io/otel/trace"
+	"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
 )
 )
 
 
 // postHijacked sends a POST request and hijacks the connection.
 // postHijacked sends a POST request and hijacks the connection.
@@ -54,33 +50,16 @@ func (cli *Client) setupHijackConn(req *http.Request, proto string) (_ net.Conn,
 	req.Header.Set("Connection", "Upgrade")
 	req.Header.Set("Connection", "Upgrade")
 	req.Header.Set("Upgrade", proto)
 	req.Header.Set("Upgrade", proto)
 
 
-	// We aren't using the configured RoundTripper here so manually inject the trace context
-	tp := cli.tp
-	if tp == nil {
-		if span := trace.SpanFromContext(ctx); span.SpanContext().IsValid() {
-			tp = span.TracerProvider()
-		} else {
-			tp = otel.GetTracerProvider()
-		}
-	}
-
-	ctx, span := tp.Tracer("").Start(ctx, req.Method+" "+req.URL.Path, trace.WithSpanKind(trace.SpanKindClient))
-	// FIXME(thaJeztah): httpconv.ClientRequest is now an internal package; replace this with alternative for semconv v1.21
-	// span.SetAttributes(httpconv.ClientRequest(req)...)
-	defer func() {
-		if retErr != nil {
-			span.RecordError(retErr)
-			span.SetStatus(codes.Error, retErr.Error())
-		}
-		span.End()
-	}()
-	otel.GetTextMapPropagator().Inject(ctx, propagation.HeaderCarrier(req.Header))
-
 	dialer := cli.Dialer()
 	dialer := cli.Dialer()
 	conn, err := dialer(ctx)
 	conn, err := dialer(ctx)
 	if err != nil {
 	if err != nil {
 		return nil, "", errors.Wrap(err, "cannot connect to the Docker daemon. Is 'docker daemon' running on this host?")
 		return nil, "", errors.Wrap(err, "cannot connect to the Docker daemon. Is 'docker daemon' running on this host?")
 	}
 	}
+	defer func() {
+		if retErr != nil {
+			conn.Close()
+		}
+	}()
 
 
 	// When we set up a TCP connection for hijack, there could be long periods
 	// When we set up a TCP connection for hijack, there could be long periods
 	// of inactivity (a long running command with no output) that in certain
 	// of inactivity (a long running command with no output) that in certain
@@ -92,58 +71,29 @@ func (cli *Client) setupHijackConn(req *http.Request, proto string) (_ net.Conn,
 		_ = tcpConn.SetKeepAlivePeriod(30 * time.Second)
 		_ = tcpConn.SetKeepAlivePeriod(30 * time.Second)
 	}
 	}
 
 
-	clientconn := httputil.NewClientConn(conn, nil)
-	defer clientconn.Close()
+	hc := &hijackedConn{conn, bufio.NewReader(conn)}
 
 
 	// Server hijacks the connection, error 'connection closed' expected
 	// Server hijacks the connection, error 'connection closed' expected
-	resp, err := clientconn.Do(req)
-	if resp != nil {
-		// This is a simplified variant of "httpconv.ClientStatus(resp.StatusCode))";
-		//
-		// The main purpose of httpconv.ClientStatus() is to detect whether the
-		// status was successful (1xx, 2xx, 3xx) or non-successful (4xx/5xx).
-		//
-		// It also provides complex logic to *validate* status-codes against
-		// a hard-coded list meant to exclude "bogus" status codes in "success"
-		// ranges (1xx, 2xx) and convert them into an error status. That code
-		// seemed over-reaching (and not accounting for potential future valid
-		// status codes). We assume we only get valid status codes, and only
-		// look at status-code ranges.
-		//
-		// For reference, see:
-		// https://github.com/open-telemetry/opentelemetry-go/blob/v1.21.0/semconv/v1.17.0/httpconv/http.go#L85-L89
-		// https://github.com/open-telemetry/opentelemetry-go/blob/v1.21.0/semconv/internal/v2/http.go#L322-L330
-		// https://github.com/open-telemetry/opentelemetry-go/blob/v1.21.0/semconv/internal/v2/http.go#L356-L404
-		code := codes.Unset
-		if resp.StatusCode >= http.StatusBadRequest {
-			code = codes.Error
-		}
-		span.SetStatus(code, "")
+	resp, err := otelhttp.NewTransport(hc).RoundTrip(req)
+	if err != nil {
+		return nil, "", err
 	}
 	}
-
-	//nolint:staticcheck // ignore SA1019 for connecting to old (pre go1.8) daemons
-	if err != httputil.ErrPersistEOF {
-		if err != nil {
-			return nil, "", err
-		}
-		if resp.StatusCode != http.StatusSwitchingProtocols {
-			_ = resp.Body.Close()
-			return nil, "", fmt.Errorf("unable to upgrade to %s, received %d", proto, resp.StatusCode)
-		}
+	if resp.StatusCode != http.StatusSwitchingProtocols {
+		_ = resp.Body.Close()
+		return nil, "", fmt.Errorf("unable to upgrade to %s, received %d", proto, resp.StatusCode)
 	}
 	}
 
 
-	c, br := clientconn.Hijack()
-	if br.Buffered() > 0 {
+	if hc.r.Buffered() > 0 {
 		// If there is buffered content, wrap the connection.  We return an
 		// If there is buffered content, wrap the connection.  We return an
 		// object that implements CloseWrite if the underlying connection
 		// object that implements CloseWrite if the underlying connection
 		// implements it.
 		// implements it.
-		if _, ok := c.(types.CloseWriter); ok {
-			c = &hijackedConnCloseWriter{&hijackedConn{c, br}}
+		if _, ok := hc.Conn.(types.CloseWriter); ok {
+			conn = &hijackedConnCloseWriter{hc}
 		} else {
 		} else {
-			c = &hijackedConn{c, br}
+			conn = hc
 		}
 		}
 	} else {
 	} else {
-		br.Reset(nil)
+		hc.r.Reset(nil)
 	}
 	}
 
 
 	var mediaType string
 	var mediaType string
@@ -152,7 +102,7 @@ func (cli *Client) setupHijackConn(req *http.Request, proto string) (_ net.Conn,
 		mediaType = resp.Header.Get("Content-Type")
 		mediaType = resp.Header.Get("Content-Type")
 	}
 	}
 
 
-	return c, mediaType, nil
+	return conn, mediaType, nil
 }
 }
 
 
 // hijackedConn wraps a net.Conn and is returned by setupHijackConn in the case
 // hijackedConn wraps a net.Conn and is returned by setupHijackConn in the case
@@ -164,6 +114,13 @@ type hijackedConn struct {
 	r *bufio.Reader
 	r *bufio.Reader
 }
 }
 
 
+func (c *hijackedConn) RoundTrip(req *http.Request) (*http.Response, error) {
+	if err := req.Write(c.Conn); err != nil {
+		return nil, err
+	}
+	return http.ReadResponse(c.r, req)
+}
+
 func (c *hijackedConn) Read(b []byte) (int, error) {
 func (c *hijackedConn) Read(b []byte) (int, error) {
 	return c.r.Read(b)
 	return c.r.Read(b)
 }
 }

+ 26 - 11
integration/plugin/authz/authz_plugin_test.go

@@ -8,7 +8,6 @@ import (
 	"io"
 	"io"
 	"net"
 	"net"
 	"net/http"
 	"net/http"
-	"net/http/httputil"
 	"net/url"
 	"net/url"
 	"os"
 	"os"
 	"path/filepath"
 	"path/filepath"
@@ -25,6 +24,7 @@ import (
 	"github.com/docker/docker/pkg/archive"
 	"github.com/docker/docker/pkg/archive"
 	"github.com/docker/docker/pkg/authorization"
 	"github.com/docker/docker/pkg/authorization"
 	"github.com/docker/docker/testutil/environment"
 	"github.com/docker/docker/testutil/environment"
+	"github.com/docker/go-connections/sockets"
 	"gotest.tools/v3/assert"
 	"gotest.tools/v3/assert"
 	"gotest.tools/v3/skip"
 	"gotest.tools/v3/skip"
 )
 )
@@ -81,6 +81,17 @@ func isAllowed(reqURI string) bool {
 	return false
 	return false
 }
 }
 
 
+func socketHTTPClient(u *url.URL) (*http.Client, error) {
+	transport := &http.Transport{}
+	err := sockets.ConfigureTransport(transport, u.Scheme, u.Path)
+	if err != nil {
+		return nil, err
+	}
+	return &http.Client{
+		Transport: transport,
+	}, nil
+}
+
 func TestAuthZPluginAllowRequest(t *testing.T) {
 func TestAuthZPluginAllowRequest(t *testing.T) {
 	ctx := setupTestV1(t)
 	ctx := setupTestV1(t)
 
 
@@ -176,15 +187,17 @@ func TestAuthZPluginAPIDenyResponse(t *testing.T) {
 	daemonURL, err := url.Parse(d.Sock())
 	daemonURL, err := url.Parse(d.Sock())
 	assert.NilError(t, err)
 	assert.NilError(t, err)
 
 
-	conn, err := net.DialTimeout(daemonURL.Scheme, daemonURL.Path, time.Second*10)
+	socketClient, err := socketHTTPClient(daemonURL)
 	assert.NilError(t, err)
 	assert.NilError(t, err)
-	c := httputil.NewClientConn(conn, nil)
-	req, err := http.NewRequest(http.MethodGet, "/version", nil)
+
+	req, err := http.NewRequestWithContext(ctx, http.MethodGet, "/version", nil)
 	assert.NilError(t, err)
 	assert.NilError(t, err)
-	req = req.WithContext(ctx)
-	resp, err := c.Do(req)
+	req.URL.Scheme = "http"
+	req.URL.Host = client.DummyHost
 
 
+	resp, err := socketClient.Do(req)
 	assert.NilError(t, err)
 	assert.NilError(t, err)
+
 	assert.DeepEqual(t, http.StatusForbidden, resp.StatusCode)
 	assert.DeepEqual(t, http.StatusForbidden, resp.StatusCode)
 }
 }
 
 
@@ -471,13 +484,15 @@ func TestAuthZPluginHeader(t *testing.T) {
 	daemonURL, err := url.Parse(d.Sock())
 	daemonURL, err := url.Parse(d.Sock())
 	assert.NilError(t, err)
 	assert.NilError(t, err)
 
 
-	conn, err := net.DialTimeout(daemonURL.Scheme, daemonURL.Path, time.Second*10)
+	socketClient, err := socketHTTPClient(daemonURL)
 	assert.NilError(t, err)
 	assert.NilError(t, err)
-	client := httputil.NewClientConn(conn, nil)
-	req, err := http.NewRequest(http.MethodGet, "/version", nil)
+
+	req, err := http.NewRequestWithContext(ctx, http.MethodGet, "/version", nil)
 	assert.NilError(t, err)
 	assert.NilError(t, err)
-	req = req.WithContext(ctx)
-	resp, err := client.Do(req)
+	req.URL.Scheme = "http"
+	req.URL.Host = client.DummyHost
+
+	resp, err := socketClient.Do(req)
 	assert.NilError(t, err)
 	assert.NilError(t, err)
 	assert.Equal(t, "application/json", resp.Header["Content-Type"][0])
 	assert.Equal(t, "application/json", resp.Header["Content-Type"][0])
 }
 }