client: Client.buildRequest: use http.NewRequestWithContext
Attach the context to the request while we're creating it, instead of creating the context first, and adding the context later. Signed-off-by: Sebastiaan van Stijn <github@gone.nl>
This commit is contained in:
parent
58dc0fcd1e
commit
4cc796ab93
3 changed files with 13 additions and 13 deletions
|
@ -21,11 +21,11 @@ func (cli *Client) postHijacked(ctx context.Context, path string, query url.Valu
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.HijackedResponse{}, err
|
return types.HijackedResponse{}, err
|
||||||
}
|
}
|
||||||
req, err := cli.buildRequest(http.MethodPost, cli.getAPIPath(ctx, path, query), bodyEncoded, headers)
|
req, err := cli.buildRequest(ctx, http.MethodPost, cli.getAPIPath(ctx, path, query), bodyEncoded, headers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.HijackedResponse{}, err
|
return types.HijackedResponse{}, err
|
||||||
}
|
}
|
||||||
conn, mediaType, err := cli.setupHijackConn(ctx, req, "tcp")
|
conn, mediaType, err := cli.setupHijackConn(req, "tcp")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.HijackedResponse{}, err
|
return types.HijackedResponse{}, err
|
||||||
}
|
}
|
||||||
|
@ -35,17 +35,18 @@ func (cli *Client) postHijacked(ctx context.Context, path string, query url.Valu
|
||||||
|
|
||||||
// DialHijack returns a hijacked connection with negotiated protocol proto.
|
// DialHijack returns a hijacked connection with negotiated protocol proto.
|
||||||
func (cli *Client) DialHijack(ctx context.Context, url, proto string, meta map[string][]string) (net.Conn, error) {
|
func (cli *Client) DialHijack(ctx context.Context, url, proto string, meta map[string][]string) (net.Conn, error) {
|
||||||
req, err := http.NewRequest(http.MethodPost, url, nil)
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
req = cli.addHeaders(req, meta)
|
req = cli.addHeaders(req, meta)
|
||||||
|
|
||||||
conn, _, err := cli.setupHijackConn(ctx, req, proto)
|
conn, _, err := cli.setupHijackConn(req, proto)
|
||||||
return conn, err
|
return conn, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cli *Client) setupHijackConn(ctx context.Context, req *http.Request, proto string) (net.Conn, string, error) {
|
func (cli *Client) setupHijackConn(req *http.Request, proto string) (net.Conn, string, error) {
|
||||||
|
ctx := req.Context()
|
||||||
req.Header.Set("Connection", "Upgrade")
|
req.Header.Set("Connection", "Upgrade")
|
||||||
req.Header.Set("Upgrade", proto)
|
req.Header.Set("Upgrade", proto)
|
||||||
|
|
||||||
|
|
|
@ -21,11 +21,11 @@ func (cli *Client) Ping(ctx context.Context) (types.Ping, error) {
|
||||||
// Using cli.buildRequest() + cli.doRequest() instead of cli.sendRequest()
|
// Using cli.buildRequest() + cli.doRequest() instead of cli.sendRequest()
|
||||||
// because ping requests are used during API version negotiation, so we want
|
// because ping requests are used during API version negotiation, so we want
|
||||||
// to hit the non-versioned /_ping endpoint, not /v1.xx/_ping
|
// to hit the non-versioned /_ping endpoint, not /v1.xx/_ping
|
||||||
req, err := cli.buildRequest(http.MethodHead, path.Join(cli.basePath, "/_ping"), nil, nil)
|
req, err := cli.buildRequest(ctx, http.MethodHead, path.Join(cli.basePath, "/_ping"), nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ping, err
|
return ping, err
|
||||||
}
|
}
|
||||||
serverResp, err := cli.doRequest(ctx, req)
|
serverResp, err := cli.doRequest(req)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
defer ensureReaderClosed(serverResp)
|
defer ensureReaderClosed(serverResp)
|
||||||
switch serverResp.statusCode {
|
switch serverResp.statusCode {
|
||||||
|
|
|
@ -96,8 +96,8 @@ func encodeBody(obj interface{}, headers http.Header) (io.Reader, http.Header, e
|
||||||
return body, headers, nil
|
return body, headers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cli *Client) buildRequest(method, path string, body io.Reader, headers http.Header) (*http.Request, error) {
|
func (cli *Client) buildRequest(ctx context.Context, method, path string, body io.Reader, headers http.Header) (*http.Request, error) {
|
||||||
req, err := http.NewRequest(method, path, body)
|
req, err := http.NewRequestWithContext(ctx, method, path, body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -117,12 +117,12 @@ func (cli *Client) buildRequest(method, path string, body io.Reader, headers htt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cli *Client) sendRequest(ctx context.Context, method, path string, query url.Values, body io.Reader, headers http.Header) (serverResponse, error) {
|
func (cli *Client) sendRequest(ctx context.Context, method, path string, query url.Values, body io.Reader, headers http.Header) (serverResponse, error) {
|
||||||
req, err := cli.buildRequest(method, cli.getAPIPath(ctx, path, query), body, headers)
|
req, err := cli.buildRequest(ctx, method, cli.getAPIPath(ctx, path, query), body, headers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return serverResponse{}, err
|
return serverResponse{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := cli.doRequest(ctx, req)
|
resp, err := cli.doRequest(req)
|
||||||
switch {
|
switch {
|
||||||
case errors.Is(err, context.Canceled):
|
case errors.Is(err, context.Canceled):
|
||||||
return serverResponse{}, errdefs.Cancelled(err)
|
return serverResponse{}, errdefs.Cancelled(err)
|
||||||
|
@ -134,10 +134,9 @@ func (cli *Client) sendRequest(ctx context.Context, method, path string, query u
|
||||||
return resp, errdefs.FromStatusCode(err, resp.statusCode)
|
return resp, errdefs.FromStatusCode(err, resp.statusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cli *Client) doRequest(ctx context.Context, req *http.Request) (serverResponse, error) {
|
func (cli *Client) doRequest(req *http.Request) (serverResponse, error) {
|
||||||
serverResp := serverResponse{statusCode: -1, reqURL: req.URL}
|
serverResp := serverResponse{statusCode: -1, reqURL: req.URL}
|
||||||
|
|
||||||
req = req.WithContext(ctx)
|
|
||||||
resp, err := cli.client.Do(req)
|
resp, err := cli.client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if cli.scheme != "https" && strings.Contains(err.Error(), "malformed HTTP response") {
|
if cli.scheme != "https" && strings.Contains(err.Error(), "malformed HTTP response") {
|
||||||
|
|
Loading…
Reference in a new issue