From 4cc796ab93a3f31a52c09450fab11865b94d4633 Mon Sep 17 00:00:00 2001 From: Sebastiaan van Stijn Date: Fri, 14 Jul 2023 18:23:59 +0200 Subject: [PATCH] client: Client.buildRequest: use http.NewRequestWithContext Attach the context to the request while we're creating it, instead of creating the context first, and adding the context later. Signed-off-by: Sebastiaan van Stijn --- client/hijack.go | 11 ++++++----- client/ping.go | 4 ++-- client/request.go | 11 +++++------ 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/client/hijack.go b/client/hijack.go index 2337a9539b..3fc17dc071 100644 --- a/client/hijack.go +++ b/client/hijack.go @@ -21,11 +21,11 @@ func (cli *Client) postHijacked(ctx context.Context, path string, query url.Valu if err != nil { return types.HijackedResponse{}, err } - req, err := cli.buildRequest(http.MethodPost, cli.getAPIPath(ctx, path, query), bodyEncoded, headers) + req, err := cli.buildRequest(ctx, http.MethodPost, cli.getAPIPath(ctx, path, query), bodyEncoded, headers) if err != nil { return types.HijackedResponse{}, err } - conn, mediaType, err := cli.setupHijackConn(ctx, req, "tcp") + conn, mediaType, err := cli.setupHijackConn(req, "tcp") if err != nil { return types.HijackedResponse{}, err } @@ -35,17 +35,18 @@ func (cli *Client) postHijacked(ctx context.Context, path string, query url.Valu // DialHijack returns a hijacked connection with negotiated protocol proto. func (cli *Client) DialHijack(ctx context.Context, url, proto string, meta map[string][]string) (net.Conn, error) { - req, err := http.NewRequest(http.MethodPost, url, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil) if err != nil { return nil, err } req = cli.addHeaders(req, meta) - conn, _, err := cli.setupHijackConn(ctx, req, proto) + conn, _, err := cli.setupHijackConn(req, proto) return conn, err } -func (cli *Client) setupHijackConn(ctx context.Context, req *http.Request, proto string) (net.Conn, string, error) { +func (cli *Client) setupHijackConn(req *http.Request, proto string) (net.Conn, string, error) { + ctx := req.Context() req.Header.Set("Connection", "Upgrade") req.Header.Set("Upgrade", proto) diff --git a/client/ping.go b/client/ping.go index 7539e54fef..dfd1042fab 100644 --- a/client/ping.go +++ b/client/ping.go @@ -21,11 +21,11 @@ func (cli *Client) Ping(ctx context.Context) (types.Ping, error) { // Using cli.buildRequest() + cli.doRequest() instead of cli.sendRequest() // because ping requests are used during API version negotiation, so we want // to hit the non-versioned /_ping endpoint, not /v1.xx/_ping - req, err := cli.buildRequest(http.MethodHead, path.Join(cli.basePath, "/_ping"), nil, nil) + req, err := cli.buildRequest(ctx, http.MethodHead, path.Join(cli.basePath, "/_ping"), nil, nil) if err != nil { return ping, err } - serverResp, err := cli.doRequest(ctx, req) + serverResp, err := cli.doRequest(req) if err == nil { defer ensureReaderClosed(serverResp) switch serverResp.statusCode { diff --git a/client/request.go b/client/request.go index a8c9437c5e..efe07bb9ea 100644 --- a/client/request.go +++ b/client/request.go @@ -96,8 +96,8 @@ func encodeBody(obj interface{}, headers http.Header) (io.Reader, http.Header, e return body, headers, nil } -func (cli *Client) buildRequest(method, path string, body io.Reader, headers http.Header) (*http.Request, error) { - req, err := http.NewRequest(method, path, body) +func (cli *Client) buildRequest(ctx context.Context, method, path string, body io.Reader, headers http.Header) (*http.Request, error) { + req, err := http.NewRequestWithContext(ctx, method, path, body) if err != nil { return nil, err } @@ -117,12 +117,12 @@ func (cli *Client) buildRequest(method, path string, body io.Reader, headers htt } func (cli *Client) sendRequest(ctx context.Context, method, path string, query url.Values, body io.Reader, headers http.Header) (serverResponse, error) { - req, err := cli.buildRequest(method, cli.getAPIPath(ctx, path, query), body, headers) + req, err := cli.buildRequest(ctx, method, cli.getAPIPath(ctx, path, query), body, headers) if err != nil { return serverResponse{}, err } - resp, err := cli.doRequest(ctx, req) + resp, err := cli.doRequest(req) switch { case errors.Is(err, context.Canceled): return serverResponse{}, errdefs.Cancelled(err) @@ -134,10 +134,9 @@ func (cli *Client) sendRequest(ctx context.Context, method, path string, query u return resp, errdefs.FromStatusCode(err, resp.statusCode) } -func (cli *Client) doRequest(ctx context.Context, req *http.Request) (serverResponse, error) { +func (cli *Client) doRequest(req *http.Request) (serverResponse, error) { serverResp := serverResponse{statusCode: -1, reqURL: req.URL} - req = req.WithContext(ctx) resp, err := cli.client.Do(req) if err != nil { if cli.scheme != "https" && strings.Contains(err.Error(), "malformed HTTP response") {