浏览代码

Merge pull request #46920 from dmcgowan/client-hijack-cleanup

Replace use of httputil in client hijack
Sebastiaan van Stijn 1 年之前
父节点
当前提交
0751141003
共有 3 个文件被更改,包括 52 次插入88 次删除
  1. 0 8
      .golangci.yml
  2. 26 69
      client/hijack.go
  3. 26 11
      integration/plugin/authz/authz_plugin_test.go

+ 0 - 8
.golangci.yml

@@ -113,14 +113,6 @@ issues:
       path: "api/types/(volume|container)/"
       linters:
         - revive
-    # FIXME temporarily suppress these. See #39926
-    - text: "SA1019: httputil.NewClientConn"
-      linters:
-        - staticcheck
-    # FIXME temporarily suppress these (related to the ones above)
-    - text: "SA1019: httputil.ErrPersistEOF"
-      linters:
-        - staticcheck
     # FIXME temporarily suppress these (see https://github.com/gotestyourself/gotest.tools/issues/272)
     - text: "SA1019: (assert|cmp|is)\\.ErrorType is deprecated"
       linters:

+ 26 - 69
client/hijack.go

@@ -6,17 +6,13 @@ import (
 	"fmt"
 	"net"
 	"net/http"
-	"net/http/httputil"
 	"net/url"
 	"time"
 
 	"github.com/docker/docker/api/types"
 	"github.com/docker/docker/api/types/versions"
 	"github.com/pkg/errors"
-	"go.opentelemetry.io/otel"
-	"go.opentelemetry.io/otel/codes"
-	"go.opentelemetry.io/otel/propagation"
-	"go.opentelemetry.io/otel/trace"
+	"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
 )
 
 // postHijacked sends a POST request and hijacks the connection.
@@ -54,33 +50,16 @@ func (cli *Client) setupHijackConn(req *http.Request, proto string) (_ net.Conn,
 	req.Header.Set("Connection", "Upgrade")
 	req.Header.Set("Upgrade", proto)
 
-	// We aren't using the configured RoundTripper here so manually inject the trace context
-	tp := cli.tp
-	if tp == nil {
-		if span := trace.SpanFromContext(ctx); span.SpanContext().IsValid() {
-			tp = span.TracerProvider()
-		} else {
-			tp = otel.GetTracerProvider()
-		}
-	}
-
-	ctx, span := tp.Tracer("").Start(ctx, req.Method+" "+req.URL.Path, trace.WithSpanKind(trace.SpanKindClient))
-	// FIXME(thaJeztah): httpconv.ClientRequest is now an internal package; replace this with alternative for semconv v1.21
-	// span.SetAttributes(httpconv.ClientRequest(req)...)
-	defer func() {
-		if retErr != nil {
-			span.RecordError(retErr)
-			span.SetStatus(codes.Error, retErr.Error())
-		}
-		span.End()
-	}()
-	otel.GetTextMapPropagator().Inject(ctx, propagation.HeaderCarrier(req.Header))
-
 	dialer := cli.Dialer()
 	conn, err := dialer(ctx)
 	if err != nil {
 		return nil, "", errors.Wrap(err, "cannot connect to the Docker daemon. Is 'docker daemon' running on this host?")
 	}
+	defer func() {
+		if retErr != nil {
+			conn.Close()
+		}
+	}()
 
 	// When we set up a TCP connection for hijack, there could be long periods
 	// of inactivity (a long running command with no output) that in certain
@@ -92,58 +71,29 @@ func (cli *Client) setupHijackConn(req *http.Request, proto string) (_ net.Conn,
 		_ = tcpConn.SetKeepAlivePeriod(30 * time.Second)
 	}
 
-	clientconn := httputil.NewClientConn(conn, nil)
-	defer clientconn.Close()
+	hc := &hijackedConn{conn, bufio.NewReader(conn)}
 
 	// Server hijacks the connection, error 'connection closed' expected
-	resp, err := clientconn.Do(req)
-	if resp != nil {
-		// This is a simplified variant of "httpconv.ClientStatus(resp.StatusCode))";
-		//
-		// The main purpose of httpconv.ClientStatus() is to detect whether the
-		// status was successful (1xx, 2xx, 3xx) or non-successful (4xx/5xx).
-		//
-		// It also provides complex logic to *validate* status-codes against
-		// a hard-coded list meant to exclude "bogus" status codes in "success"
-		// ranges (1xx, 2xx) and convert them into an error status. That code
-		// seemed over-reaching (and not accounting for potential future valid
-		// status codes). We assume we only get valid status codes, and only
-		// look at status-code ranges.
-		//
-		// For reference, see:
-		// https://github.com/open-telemetry/opentelemetry-go/blob/v1.21.0/semconv/v1.17.0/httpconv/http.go#L85-L89
-		// https://github.com/open-telemetry/opentelemetry-go/blob/v1.21.0/semconv/internal/v2/http.go#L322-L330
-		// https://github.com/open-telemetry/opentelemetry-go/blob/v1.21.0/semconv/internal/v2/http.go#L356-L404
-		code := codes.Unset
-		if resp.StatusCode >= http.StatusBadRequest {
-			code = codes.Error
-		}
-		span.SetStatus(code, "")
+	resp, err := otelhttp.NewTransport(hc).RoundTrip(req)
+	if err != nil {
+		return nil, "", err
 	}
-
-	//nolint:staticcheck // ignore SA1019 for connecting to old (pre go1.8) daemons
-	if err != httputil.ErrPersistEOF {
-		if err != nil {
-			return nil, "", err
-		}
-		if resp.StatusCode != http.StatusSwitchingProtocols {
-			_ = resp.Body.Close()
-			return nil, "", fmt.Errorf("unable to upgrade to %s, received %d", proto, resp.StatusCode)
-		}
+	if resp.StatusCode != http.StatusSwitchingProtocols {
+		_ = resp.Body.Close()
+		return nil, "", fmt.Errorf("unable to upgrade to %s, received %d", proto, resp.StatusCode)
 	}
 
-	c, br := clientconn.Hijack()
-	if br.Buffered() > 0 {
+	if hc.r.Buffered() > 0 {
 		// If there is buffered content, wrap the connection.  We return an
 		// object that implements CloseWrite if the underlying connection
 		// implements it.
-		if _, ok := c.(types.CloseWriter); ok {
-			c = &hijackedConnCloseWriter{&hijackedConn{c, br}}
+		if _, ok := hc.Conn.(types.CloseWriter); ok {
+			conn = &hijackedConnCloseWriter{hc}
 		} else {
-			c = &hijackedConn{c, br}
+			conn = hc
 		}
 	} else {
-		br.Reset(nil)
+		hc.r.Reset(nil)
 	}
 
 	var mediaType string
@@ -152,7 +102,7 @@ func (cli *Client) setupHijackConn(req *http.Request, proto string) (_ net.Conn,
 		mediaType = resp.Header.Get("Content-Type")
 	}
 
-	return c, mediaType, nil
+	return conn, mediaType, nil
 }
 
 // hijackedConn wraps a net.Conn and is returned by setupHijackConn in the case
@@ -164,6 +114,13 @@ type hijackedConn struct {
 	r *bufio.Reader
 }
 
+func (c *hijackedConn) RoundTrip(req *http.Request) (*http.Response, error) {
+	if err := req.Write(c.Conn); err != nil {
+		return nil, err
+	}
+	return http.ReadResponse(c.r, req)
+}
+
 func (c *hijackedConn) Read(b []byte) (int, error) {
 	return c.r.Read(b)
 }

+ 26 - 11
integration/plugin/authz/authz_plugin_test.go

@@ -8,7 +8,6 @@ import (
 	"io"
 	"net"
 	"net/http"
-	"net/http/httputil"
 	"net/url"
 	"os"
 	"path/filepath"
@@ -25,6 +24,7 @@ import (
 	"github.com/docker/docker/pkg/archive"
 	"github.com/docker/docker/pkg/authorization"
 	"github.com/docker/docker/testutil/environment"
+	"github.com/docker/go-connections/sockets"
 	"gotest.tools/v3/assert"
 	"gotest.tools/v3/skip"
 )
@@ -81,6 +81,17 @@ func isAllowed(reqURI string) bool {
 	return false
 }
 
+func socketHTTPClient(u *url.URL) (*http.Client, error) {
+	transport := &http.Transport{}
+	err := sockets.ConfigureTransport(transport, u.Scheme, u.Path)
+	if err != nil {
+		return nil, err
+	}
+	return &http.Client{
+		Transport: transport,
+	}, nil
+}
+
 func TestAuthZPluginAllowRequest(t *testing.T) {
 	ctx := setupTestV1(t)
 
@@ -176,15 +187,17 @@ func TestAuthZPluginAPIDenyResponse(t *testing.T) {
 	daemonURL, err := url.Parse(d.Sock())
 	assert.NilError(t, err)
 
-	conn, err := net.DialTimeout(daemonURL.Scheme, daemonURL.Path, time.Second*10)
+	socketClient, err := socketHTTPClient(daemonURL)
 	assert.NilError(t, err)
-	c := httputil.NewClientConn(conn, nil)
-	req, err := http.NewRequest(http.MethodGet, "/version", nil)
+
+	req, err := http.NewRequestWithContext(ctx, http.MethodGet, "/version", nil)
 	assert.NilError(t, err)
-	req = req.WithContext(ctx)
-	resp, err := c.Do(req)
+	req.URL.Scheme = "http"
+	req.URL.Host = client.DummyHost
 
+	resp, err := socketClient.Do(req)
 	assert.NilError(t, err)
+
 	assert.DeepEqual(t, http.StatusForbidden, resp.StatusCode)
 }
 
@@ -471,13 +484,15 @@ func TestAuthZPluginHeader(t *testing.T) {
 	daemonURL, err := url.Parse(d.Sock())
 	assert.NilError(t, err)
 
-	conn, err := net.DialTimeout(daemonURL.Scheme, daemonURL.Path, time.Second*10)
+	socketClient, err := socketHTTPClient(daemonURL)
 	assert.NilError(t, err)
-	client := httputil.NewClientConn(conn, nil)
-	req, err := http.NewRequest(http.MethodGet, "/version", nil)
+
+	req, err := http.NewRequestWithContext(ctx, http.MethodGet, "/version", nil)
 	assert.NilError(t, err)
-	req = req.WithContext(ctx)
-	resp, err := client.Do(req)
+	req.URL.Scheme = "http"
+	req.URL.Host = client.DummyHost
+
+	resp, err := socketClient.Do(req)
 	assert.NilError(t, err)
 	assert.Equal(t, "application/json", resp.Header["Content-Type"][0])
 }