Browse Source

Merge pull request #45677 from thaJeztah/client_useragent

client: add WithUserAgent() option
Akihiro Suda 2 years ago
parent
commit
78cf11575d
4 changed files with 104 additions and 1 deletions
  1. 5 0
      client/client.go
  2. 10 0
      client/options.go
  3. 77 0
      client/options_test.go
  4. 12 1
      client/request.go

+ 5 - 0
client/client.go

@@ -76,6 +76,11 @@ type Client struct {
 	client *http.Client
 	client *http.Client
 	// version of the server to talk to.
 	// version of the server to talk to.
 	version string
 	version string
+	// userAgent is the User-Agent header to use for HTTP requests. It takes
+	// precedence over User-Agent headers set in customHTTPHeaders, and other
+	// header variables. When set to an empty string, the User-Agent header
+	// is removed, and no header is sent.
+	userAgent *string
 	// custom http headers configured by users.
 	// custom http headers configured by users.
 	customHTTPHeaders map[string]string
 	customHTTPHeaders map[string]string
 	// manualOverride is set to true when the version was set by users.
 	// manualOverride is set to true when the version was set by users.

+ 10 - 0
client/options.go

@@ -104,6 +104,16 @@ func WithTimeout(timeout time.Duration) Opt {
 	}
 	}
 }
 }
 
 
+// WithUserAgent configures the User-Agent header to use for HTTP requests.
+// It overrides any User-Agent set in headers. When set to an empty string,
+// the User-Agent header is removed, and no header is sent.
+func WithUserAgent(ua string) Opt {
+	return func(c *Client) error {
+		c.userAgent = &ua
+		return nil
+	}
+}
+
 // WithHTTPHeaders overrides the client default http headers
 // WithHTTPHeaders overrides the client default http headers
 func WithHTTPHeaders(headers map[string]string) Opt {
 func WithHTTPHeaders(headers map[string]string) Opt {
 	return func(c *Client) error {
 	return func(c *Client) error {

+ 77 - 0
client/options_test.go

@@ -1,6 +1,8 @@
 package client
 package client
 
 
 import (
 import (
+	"context"
+	"net/http"
 	"runtime"
 	"runtime"
 	"testing"
 	"testing"
 	"time"
 	"time"
@@ -59,3 +61,78 @@ func TestOptionWithVersionFromEnv(t *testing.T) {
 	assert.Equal(t, c.version, "2.9999")
 	assert.Equal(t, c.version, "2.9999")
 	assert.Equal(t, c.manualOverride, true)
 	assert.Equal(t, c.manualOverride, true)
 }
 }
+
+func TestWithUserAgent(t *testing.T) {
+	const userAgent = "Magic-Client/v1.2.3"
+	t.Run("user-agent", func(t *testing.T) {
+		c, err := NewClientWithOpts(
+			WithUserAgent(userAgent),
+			WithHTTPClient(newMockClient(func(req *http.Request) (*http.Response, error) {
+				assert.Check(t, is.Equal(req.Header.Get("User-Agent"), userAgent))
+				return &http.Response{StatusCode: http.StatusOK}, nil
+			})),
+		)
+		assert.Check(t, err)
+		_, err = c.Ping(context.Background())
+		assert.Check(t, err)
+		assert.Check(t, c.Close())
+	})
+	t.Run("user-agent and custom headers", func(t *testing.T) {
+		c, err := NewClientWithOpts(
+			WithUserAgent(userAgent),
+			WithHTTPHeaders(map[string]string{"User-Agent": "should-be-ignored/1.0.0", "Other-Header": "hello-world"}),
+			WithHTTPClient(newMockClient(func(req *http.Request) (*http.Response, error) {
+				assert.Check(t, is.Equal(req.Header.Get("User-Agent"), userAgent))
+				assert.Check(t, is.Equal(req.Header.Get("Other-Header"), "hello-world"))
+				return &http.Response{StatusCode: http.StatusOK}, nil
+			})),
+		)
+		assert.Check(t, err)
+		_, err = c.Ping(context.Background())
+		assert.Check(t, err)
+		assert.Check(t, c.Close())
+	})
+	t.Run("custom headers", func(t *testing.T) {
+		c, err := NewClientWithOpts(
+			WithHTTPHeaders(map[string]string{"User-Agent": "from-custom-headers/1.0.0", "Other-Header": "hello-world"}),
+			WithHTTPClient(newMockClient(func(req *http.Request) (*http.Response, error) {
+				assert.Check(t, is.Equal(req.Header.Get("User-Agent"), "from-custom-headers/1.0.0"))
+				assert.Check(t, is.Equal(req.Header.Get("Other-Header"), "hello-world"))
+				return &http.Response{StatusCode: http.StatusOK}, nil
+			})),
+		)
+		assert.Check(t, err)
+		_, err = c.Ping(context.Background())
+		assert.Check(t, err)
+		assert.Check(t, c.Close())
+	})
+	t.Run("no user-agent set", func(t *testing.T) {
+		c, err := NewClientWithOpts(
+			WithHTTPHeaders(map[string]string{"Other-Header": "hello-world"}),
+			WithHTTPClient(newMockClient(func(req *http.Request) (*http.Response, error) {
+				assert.Check(t, is.Equal(req.Header.Get("User-Agent"), ""))
+				assert.Check(t, is.Equal(req.Header.Get("Other-Header"), "hello-world"))
+				return &http.Response{StatusCode: http.StatusOK}, nil
+			})),
+		)
+		assert.Check(t, err)
+		_, err = c.Ping(context.Background())
+		assert.Check(t, err)
+		assert.Check(t, c.Close())
+	})
+	t.Run("reset custom user-agent", func(t *testing.T) {
+		c, err := NewClientWithOpts(
+			WithUserAgent(""),
+			WithHTTPHeaders(map[string]string{"User-Agent": "from-custom-headers/1.0.0", "Other-Header": "hello-world"}),
+			WithHTTPClient(newMockClient(func(req *http.Request) (*http.Response, error) {
+				assert.Check(t, is.Equal(req.Header.Get("User-Agent"), ""))
+				assert.Check(t, is.Equal(req.Header.Get("Other-Header"), "hello-world"))
+				return &http.Response{StatusCode: http.StatusOK}, nil
+			})),
+		)
+		assert.Check(t, err)
+		_, err = c.Ping(context.Background())
+		assert.Check(t, err)
+		assert.Check(t, c.Close())
+	})
+}

+ 12 - 1
client/request.go

@@ -107,7 +107,10 @@ func (cli *Client) buildRequest(method, path string, body io.Reader, headers hea
 
 
 	if cli.proto == "unix" || cli.proto == "npipe" {
 	if cli.proto == "unix" || cli.proto == "npipe" {
 		// For local communications, it doesn't matter what the host is. We just
 		// For local communications, it doesn't matter what the host is. We just
-		// need a valid and meaningful host name. (See #189)
+		// need a valid and meaningful host name. For details, see:
+		//
+		// - https://github.com/docker/engine-api/issues/189
+		// - https://github.com/golang/go/issues/13624
 		req.Host = "docker"
 		req.Host = "docker"
 	}
 	}
 
 
@@ -263,6 +266,14 @@ func (cli *Client) addHeaders(req *http.Request, headers headers) *http.Request
 	for k, v := range headers {
 	for k, v := range headers {
 		req.Header[http.CanonicalHeaderKey(k)] = v
 		req.Header[http.CanonicalHeaderKey(k)] = v
 	}
 	}
+
+	if cli.userAgent != nil {
+		if *cli.userAgent == "" {
+			req.Header.Del("User-Agent")
+		} else {
+			req.Header.Set("User-Agent", *cli.userAgent)
+		}
+	}
 	return req
 	return req
 }
 }