Browse Source

Merge pull request #46171 from thaJeztah/client_context

client: Client.buildRequest: use http.NewRequestWithContext
Sebastiaan van Stijn 1 năm trước cách đây
mục cha
commit
e0da5cb929
3 tập tin đã thay đổi với 20 bổ sung23 xóa
  1. 6 5
      client/hijack.go
  2. 5 7
      client/ping.go
  3. 9 11
      client/request.go

+ 6 - 5
client/hijack.go

@@ -21,11 +21,11 @@ func (cli *Client) postHijacked(ctx context.Context, path string, query url.Valu
 	if err != nil {
 		return types.HijackedResponse{}, err
 	}
-	req, err := cli.buildRequest(http.MethodPost, cli.getAPIPath(ctx, path, query), bodyEncoded, headers)
+	req, err := cli.buildRequest(ctx, http.MethodPost, cli.getAPIPath(ctx, path, query), bodyEncoded, headers)
 	if err != nil {
 		return types.HijackedResponse{}, err
 	}
-	conn, mediaType, err := cli.setupHijackConn(ctx, req, "tcp")
+	conn, mediaType, err := cli.setupHijackConn(req, "tcp")
 	if err != nil {
 		return types.HijackedResponse{}, err
 	}
@@ -35,17 +35,18 @@ func (cli *Client) postHijacked(ctx context.Context, path string, query url.Valu
 
 // DialHijack returns a hijacked connection with negotiated protocol proto.
 func (cli *Client) DialHijack(ctx context.Context, url, proto string, meta map[string][]string) (net.Conn, error) {
-	req, err := http.NewRequest(http.MethodPost, url, nil)
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil)
 	if err != nil {
 		return nil, err
 	}
 	req = cli.addHeaders(req, meta)
 
-	conn, _, err := cli.setupHijackConn(ctx, req, proto)
+	conn, _, err := cli.setupHijackConn(req, proto)
 	return conn, err
 }
 
-func (cli *Client) setupHijackConn(ctx context.Context, req *http.Request, proto string) (net.Conn, string, error) {
+func (cli *Client) setupHijackConn(req *http.Request, proto string) (net.Conn, string, error) {
+	ctx := req.Context()
 	req.Header.Set("Connection", "Upgrade")
 	req.Header.Set("Upgrade", proto)
 

+ 5 - 7
client/ping.go

@@ -21,11 +21,11 @@ func (cli *Client) Ping(ctx context.Context) (types.Ping, error) {
 	// Using cli.buildRequest() + cli.doRequest() instead of cli.sendRequest()
 	// because ping requests are used during API version negotiation, so we want
 	// to hit the non-versioned /_ping endpoint, not /v1.xx/_ping
-	req, err := cli.buildRequest(http.MethodHead, path.Join(cli.basePath, "/_ping"), nil, nil)
+	req, err := cli.buildRequest(ctx, http.MethodHead, path.Join(cli.basePath, "/_ping"), nil, nil)
 	if err != nil {
 		return ping, err
 	}
-	serverResp, err := cli.doRequest(ctx, req)
+	serverResp, err := cli.doRequest(req)
 	if err == nil {
 		defer ensureReaderClosed(serverResp)
 		switch serverResp.statusCode {
@@ -37,11 +37,9 @@ func (cli *Client) Ping(ctx context.Context) (types.Ping, error) {
 		return ping, err
 	}
 
-	req, err = cli.buildRequest(http.MethodGet, path.Join(cli.basePath, "/_ping"), nil, nil)
-	if err != nil {
-		return ping, err
-	}
-	serverResp, err = cli.doRequest(ctx, req)
+	// HEAD failed; fallback to GET.
+	req.Method = http.MethodGet
+	serverResp, err = cli.doRequest(req)
 	defer ensureReaderClosed(serverResp)
 	if err != nil {
 		return ping, err

+ 9 - 11
client/request.go

@@ -96,8 +96,8 @@ func encodeBody(obj interface{}, headers http.Header) (io.Reader, http.Header, e
 	return body, headers, nil
 }
 
-func (cli *Client) buildRequest(method, path string, body io.Reader, headers http.Header) (*http.Request, error) {
-	req, err := http.NewRequest(method, path, body)
+func (cli *Client) buildRequest(ctx context.Context, method, path string, body io.Reader, headers http.Header) (*http.Request, error) {
+	req, err := http.NewRequestWithContext(ctx, method, path, body)
 	if err != nil {
 		return nil, err
 	}
@@ -117,12 +117,12 @@ func (cli *Client) buildRequest(method, path string, body io.Reader, headers htt
 }
 
 func (cli *Client) sendRequest(ctx context.Context, method, path string, query url.Values, body io.Reader, headers http.Header) (serverResponse, error) {
-	req, err := cli.buildRequest(method, cli.getAPIPath(ctx, path, query), body, headers)
+	req, err := cli.buildRequest(ctx, method, cli.getAPIPath(ctx, path, query), body, headers)
 	if err != nil {
 		return serverResponse{}, err
 	}
 
-	resp, err := cli.doRequest(ctx, req)
+	resp, err := cli.doRequest(req)
 	switch {
 	case errors.Is(err, context.Canceled):
 		return serverResponse{}, errdefs.Cancelled(err)
@@ -134,10 +134,9 @@ func (cli *Client) sendRequest(ctx context.Context, method, path string, query u
 	return resp, errdefs.FromStatusCode(err, resp.statusCode)
 }
 
-func (cli *Client) doRequest(ctx context.Context, req *http.Request) (serverResponse, error) {
+func (cli *Client) doRequest(req *http.Request) (serverResponse, error) {
 	serverResp := serverResponse{statusCode: -1, reqURL: req.URL}
 
-	req = req.WithContext(ctx)
 	resp, err := cli.client.Do(req)
 	if err != nil {
 		if cli.scheme != "https" && strings.Contains(err.Error(), "malformed HTTP response") {
@@ -227,18 +226,17 @@ func (cli *Client) checkResponseErr(serverResp serverResponse) error {
 		return fmt.Errorf("request returned %s for API route and version %s, check if the server supports the requested API version", http.StatusText(serverResp.statusCode), serverResp.reqURL)
 	}
 
-	var errorMessage string
+	var daemonErr error
 	if serverResp.header.Get("Content-Type") == "application/json" && (cli.version == "" || versions.GreaterThan(cli.version, "1.23")) {
 		var errorResponse types.ErrorResponse
 		if err := json.Unmarshal(body, &errorResponse); err != nil {
 			return errors.Wrap(err, "Error reading JSON")
 		}
-		errorMessage = strings.TrimSpace(errorResponse.Message)
+		daemonErr = errors.New(strings.TrimSpace(errorResponse.Message))
 	} else {
-		errorMessage = strings.TrimSpace(string(body))
+		daemonErr = errors.New(strings.TrimSpace(string(body)))
 	}
-
-	return errors.Wrap(errors.New(errorMessage), "Error response from daemon")
+	return errors.Wrap(daemonErr, "Error response from daemon")
 }
 
 func (cli *Client) addHeaders(req *http.Request, headers http.Header) *http.Request {