|
@@ -10,6 +10,7 @@ import (
|
|
|
"bufio"
|
|
|
"bytes"
|
|
|
"compress/gzip"
|
|
|
+ "context"
|
|
|
"crypto/rand"
|
|
|
"crypto/tls"
|
|
|
"errors"
|
|
@@ -21,6 +22,7 @@ import (
|
|
|
mathrand "math/rand"
|
|
|
"net"
|
|
|
"net/http"
|
|
|
+ "net/http/httptrace"
|
|
|
"net/textproto"
|
|
|
"sort"
|
|
|
"strconv"
|
|
@@ -95,6 +97,16 @@ type Transport struct {
|
|
|
// to mean no limit.
|
|
|
MaxHeaderListSize uint32
|
|
|
|
|
|
+ // StrictMaxConcurrentStreams controls whether the server's
|
|
|
+ // SETTINGS_MAX_CONCURRENT_STREAMS should be respected
|
|
|
+ // globally. If false, new TCP connections are created to the
|
|
|
+ // server as needed to keep each under the per-connection
|
|
|
+ // SETTINGS_MAX_CONCURRENT_STREAMS limit. If true, the
|
|
|
+ // server's SETTINGS_MAX_CONCURRENT_STREAMS is interpreted as
|
|
|
+ // a global limit and callers of RoundTrip block when needed,
|
|
|
+ // waiting for their turn.
|
|
|
+ StrictMaxConcurrentStreams bool
|
|
|
+
|
|
|
// t1, if non-nil, is the standard library Transport using
|
|
|
// this transport. Its settings are used (but not its
|
|
|
// RoundTrip method, etc).
|
|
@@ -118,16 +130,56 @@ func (t *Transport) disableCompression() bool {
|
|
|
return t.DisableCompression || (t.t1 != nil && t.t1.DisableCompression)
|
|
|
}
|
|
|
|
|
|
-var errTransportVersion = errors.New("http2: ConfigureTransport is only supported starting at Go 1.6")
|
|
|
-
|
|
|
// ConfigureTransport configures a net/http HTTP/1 Transport to use HTTP/2.
|
|
|
-// It requires Go 1.6 or later and returns an error if the net/http package is too old
|
|
|
-// or if t1 has already been HTTP/2-enabled.
|
|
|
+// It returns an error if t1 has already been HTTP/2-enabled.
|
|
|
func ConfigureTransport(t1 *http.Transport) error {
|
|
|
- _, err := configureTransport(t1) // in configure_transport.go (go1.6) or not_go16.go
|
|
|
+ _, err := configureTransport(t1)
|
|
|
return err
|
|
|
}
|
|
|
|
|
|
+func configureTransport(t1 *http.Transport) (*Transport, error) {
|
|
|
+ connPool := new(clientConnPool)
|
|
|
+ t2 := &Transport{
|
|
|
+ ConnPool: noDialClientConnPool{connPool},
|
|
|
+ t1: t1,
|
|
|
+ }
|
|
|
+ connPool.t = t2
|
|
|
+ if err := registerHTTPSProtocol(t1, noDialH2RoundTripper{t2}); err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ if t1.TLSClientConfig == nil {
|
|
|
+ t1.TLSClientConfig = new(tls.Config)
|
|
|
+ }
|
|
|
+ if !strSliceContains(t1.TLSClientConfig.NextProtos, "h2") {
|
|
|
+ t1.TLSClientConfig.NextProtos = append([]string{"h2"}, t1.TLSClientConfig.NextProtos...)
|
|
|
+ }
|
|
|
+ if !strSliceContains(t1.TLSClientConfig.NextProtos, "http/1.1") {
|
|
|
+ t1.TLSClientConfig.NextProtos = append(t1.TLSClientConfig.NextProtos, "http/1.1")
|
|
|
+ }
|
|
|
+ upgradeFn := func(authority string, c *tls.Conn) http.RoundTripper {
|
|
|
+ addr := authorityAddr("https", authority)
|
|
|
+ if used, err := connPool.addConnIfNeeded(addr, t2, c); err != nil {
|
|
|
+ go c.Close()
|
|
|
+ return erringRoundTripper{err}
|
|
|
+ } else if !used {
|
|
|
+ // Turns out we don't need this c.
|
|
|
+ // For example, two goroutines made requests to the same host
|
|
|
+ // at the same time, both kicking off TCP dials. (since protocol
|
|
|
+ // was unknown)
|
|
|
+ go c.Close()
|
|
|
+ }
|
|
|
+ return t2
|
|
|
+ }
|
|
|
+ if m := t1.TLSNextProto; len(m) == 0 {
|
|
|
+ t1.TLSNextProto = map[string]func(string, *tls.Conn) http.RoundTripper{
|
|
|
+ "h2": upgradeFn,
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ m["h2"] = upgradeFn
|
|
|
+ }
|
|
|
+ return t2, nil
|
|
|
+}
|
|
|
+
|
|
|
func (t *Transport) connPool() ClientConnPool {
|
|
|
t.connPoolOnce.Do(t.initConnPool)
|
|
|
return t.connPoolOrDef
|
|
@@ -192,7 +244,7 @@ type ClientConn struct {
|
|
|
type clientStream struct {
|
|
|
cc *ClientConn
|
|
|
req *http.Request
|
|
|
- trace *clientTrace // or nil
|
|
|
+ trace *httptrace.ClientTrace // or nil
|
|
|
ID uint32
|
|
|
resc chan resAndError
|
|
|
bufPipe pipe // buffered pipe with the flow-controlled response payload
|
|
@@ -226,7 +278,7 @@ type clientStream struct {
|
|
|
// channel to be signaled. A non-nil error is returned only if the request was
|
|
|
// canceled.
|
|
|
func awaitRequestCancel(req *http.Request, done <-chan struct{}) error {
|
|
|
- ctx := reqContext(req)
|
|
|
+ ctx := req.Context()
|
|
|
if req.Cancel == nil && ctx.Done() == nil {
|
|
|
return nil
|
|
|
}
|
|
@@ -401,8 +453,8 @@ func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res
|
|
|
select {
|
|
|
case <-time.After(time.Second * time.Duration(backoff)):
|
|
|
continue
|
|
|
- case <-reqContext(req).Done():
|
|
|
- return nil, reqContext(req).Err()
|
|
|
+ case <-req.Context().Done():
|
|
|
+ return nil, req.Context().Err()
|
|
|
}
|
|
|
}
|
|
|
}
|
|
@@ -439,16 +491,15 @@ func shouldRetryRequest(req *http.Request, err error, afterBodyWrite bool) (*htt
|
|
|
}
|
|
|
// If the Body is nil (or http.NoBody), it's safe to reuse
|
|
|
// this request and its Body.
|
|
|
- if req.Body == nil || reqBodyIsNoBody(req.Body) {
|
|
|
+ if req.Body == nil || req.Body == http.NoBody {
|
|
|
return req, nil
|
|
|
}
|
|
|
|
|
|
// If the request body can be reset back to its original
|
|
|
// state via the optional req.GetBody, do that.
|
|
|
- getBody := reqGetBody(req) // Go 1.8: getBody = req.GetBody
|
|
|
- if getBody != nil {
|
|
|
+ if req.GetBody != nil {
|
|
|
// TODO: consider a req.Body.Close here? or audit that all caller paths do?
|
|
|
- body, err := getBody()
|
|
|
+ body, err := req.GetBody()
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
@@ -494,7 +545,7 @@ func (t *Transport) dialClientConn(addr string, singleUse bool) (*ClientConn, er
|
|
|
func (t *Transport) newTLSConfig(host string) *tls.Config {
|
|
|
cfg := new(tls.Config)
|
|
|
if t.TLSClientConfig != nil {
|
|
|
- *cfg = *cloneTLSConfig(t.TLSClientConfig)
|
|
|
+ *cfg = *t.TLSClientConfig.Clone()
|
|
|
}
|
|
|
if !strSliceContains(cfg.NextProtos, NextProtoTLS) {
|
|
|
cfg.NextProtos = append([]string{NextProtoTLS}, cfg.NextProtos...)
|
|
@@ -545,7 +596,7 @@ func (t *Transport) expectContinueTimeout() time.Duration {
|
|
|
if t.t1 == nil {
|
|
|
return 0
|
|
|
}
|
|
|
- return transportExpectContinueTimeout(t.t1)
|
|
|
+ return t.t1.ExpectContinueTimeout
|
|
|
}
|
|
|
|
|
|
func (t *Transport) NewClientConn(c net.Conn) (*ClientConn, error) {
|
|
@@ -670,8 +721,19 @@ func (cc *ClientConn) idleStateLocked() (st clientConnIdleState) {
|
|
|
if cc.singleUse && cc.nextStreamID > 1 {
|
|
|
return
|
|
|
}
|
|
|
- st.canTakeNewRequest = cc.goAway == nil && !cc.closed && !cc.closing &&
|
|
|
- int64(cc.nextStreamID)+int64(cc.pendingRequests) < math.MaxInt32
|
|
|
+ var maxConcurrentOkay bool
|
|
|
+ if cc.t.StrictMaxConcurrentStreams {
|
|
|
+ // We'll tell the caller we can take a new request to
|
|
|
+ // prevent the caller from dialing a new TCP
|
|
|
+ // connection, but then we'll block later before
|
|
|
+ // writing it.
|
|
|
+ maxConcurrentOkay = true
|
|
|
+ } else {
|
|
|
+ maxConcurrentOkay = int64(len(cc.streams)+1) < int64(cc.maxConcurrentStreams)
|
|
|
+ }
|
|
|
+
|
|
|
+ st.canTakeNewRequest = cc.goAway == nil && !cc.closed && !cc.closing && maxConcurrentOkay &&
|
|
|
+ int64(cc.nextStreamID)+2*int64(cc.pendingRequests) < math.MaxInt32
|
|
|
st.freshConn = cc.nextStreamID == 1 && st.canTakeNewRequest
|
|
|
return
|
|
|
}
|
|
@@ -711,8 +773,7 @@ func (cc *ClientConn) closeIfIdle() {
|
|
|
var shutdownEnterWaitStateHook = func() {}
|
|
|
|
|
|
// Shutdown gracefully close the client connection, waiting for running streams to complete.
|
|
|
-// Public implementation is in go17.go and not_go17.go
|
|
|
-func (cc *ClientConn) shutdown(ctx contextContext) error {
|
|
|
+func (cc *ClientConn) Shutdown(ctx context.Context) error {
|
|
|
if err := cc.sendGoAway(); err != nil {
|
|
|
return err
|
|
|
}
|
|
@@ -882,7 +943,7 @@ func checkConnHeaders(req *http.Request) error {
|
|
|
// req.ContentLength, where 0 actually means zero (not unknown) and -1
|
|
|
// means unknown.
|
|
|
func actualContentLength(req *http.Request) int64 {
|
|
|
- if req.Body == nil || reqBodyIsNoBody(req.Body) {
|
|
|
+ if req.Body == nil || req.Body == http.NoBody {
|
|
|
return 0
|
|
|
}
|
|
|
if req.ContentLength != 0 {
|
|
@@ -952,7 +1013,7 @@ func (cc *ClientConn) roundTrip(req *http.Request) (res *http.Response, gotErrAf
|
|
|
|
|
|
cs := cc.newStream()
|
|
|
cs.req = req
|
|
|
- cs.trace = requestTrace(req)
|
|
|
+ cs.trace = httptrace.ContextClientTrace(req.Context())
|
|
|
cs.requestedGzip = requestedGzip
|
|
|
bodyWriter := cc.t.getBodyWriterState(cs, body)
|
|
|
cs.on100 = bodyWriter.on100
|
|
@@ -990,7 +1051,7 @@ func (cc *ClientConn) roundTrip(req *http.Request) (res *http.Response, gotErrAf
|
|
|
|
|
|
readLoopResCh := cs.resc
|
|
|
bodyWritten := false
|
|
|
- ctx := reqContext(req)
|
|
|
+ ctx := req.Context()
|
|
|
|
|
|
handleReadLoopResponse := func(re resAndError) (*http.Response, bool, error) {
|
|
|
res := re.res
|
|
@@ -1060,6 +1121,7 @@ func (cc *ClientConn) roundTrip(req *http.Request) (res *http.Response, gotErrAf
|
|
|
default:
|
|
|
}
|
|
|
if err != nil {
|
|
|
+ cc.forgetStreamID(cs.ID)
|
|
|
return nil, cs.getStartedWrite(), err
|
|
|
}
|
|
|
bodyWritten = true
|
|
@@ -1181,6 +1243,7 @@ func (cs *clientStream) writeRequestBody(body io.Reader, bodyCloser io.Closer) (
|
|
|
sawEOF = true
|
|
|
err = nil
|
|
|
} else if err != nil {
|
|
|
+ cc.writeStreamReset(cs.ID, ErrCodeCancel, err)
|
|
|
return err
|
|
|
}
|
|
|
|
|
@@ -1348,7 +1411,11 @@ func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trail
|
|
|
// followed by the query production (see Sections 3.3 and 3.4 of
|
|
|
// [RFC3986]).
|
|
|
f(":authority", host)
|
|
|
- f(":method", req.Method)
|
|
|
+ m := req.Method
|
|
|
+ if m == "" {
|
|
|
+ m = http.MethodGet
|
|
|
+ }
|
|
|
+ f(":method", m)
|
|
|
if req.Method != "CONNECT" {
|
|
|
f(":path", path)
|
|
|
f(":scheme", req.URL.Scheme)
|
|
@@ -1416,7 +1483,7 @@ func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trail
|
|
|
return nil, errRequestHeaderListSize
|
|
|
}
|
|
|
|
|
|
- trace := requestTrace(req)
|
|
|
+ trace := httptrace.ContextClientTrace(req.Context())
|
|
|
traceHeaders := traceHasWroteHeaderField(trace)
|
|
|
|
|
|
// Header list size is ok. Write the headers.
|
|
@@ -1839,7 +1906,7 @@ func (rl *clientConnReadLoop) handleResponse(cs *clientStream, f *MetaHeadersFra
|
|
|
res.Header.Del("Content-Length")
|
|
|
res.ContentLength = -1
|
|
|
res.Body = &gzipReader{body: res.Body}
|
|
|
- setResponseUncompressed(res)
|
|
|
+ res.Uncompressed = true
|
|
|
}
|
|
|
return res, nil
|
|
|
}
|
|
@@ -2216,8 +2283,7 @@ func (rl *clientConnReadLoop) processResetStream(f *RSTStreamFrame) error {
|
|
|
}
|
|
|
|
|
|
// Ping sends a PING frame to the server and waits for the ack.
|
|
|
-// Public implementation is in go17.go and not_go17.go
|
|
|
-func (cc *ClientConn) ping(ctx contextContext) error {
|
|
|
+func (cc *ClientConn) Ping(ctx context.Context) error {
|
|
|
c := make(chan struct{})
|
|
|
// Generate a random payload
|
|
|
var p [8]byte
|
|
@@ -2451,3 +2517,91 @@ func (s bodyWriterState) scheduleBodyWrite() {
|
|
|
func isConnectionCloseRequest(req *http.Request) bool {
|
|
|
return req.Close || httpguts.HeaderValuesContainsToken(req.Header["Connection"], "close")
|
|
|
}
|
|
|
+
|
|
|
+// registerHTTPSProtocol calls Transport.RegisterProtocol but
|
|
|
+// converting panics into errors.
|
|
|
+func registerHTTPSProtocol(t *http.Transport, rt noDialH2RoundTripper) (err error) {
|
|
|
+ defer func() {
|
|
|
+ if e := recover(); e != nil {
|
|
|
+ err = fmt.Errorf("%v", e)
|
|
|
+ }
|
|
|
+ }()
|
|
|
+ t.RegisterProtocol("https", rt)
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+// noDialH2RoundTripper is a RoundTripper which only tries to complete the request
|
|
|
+// if there's already has a cached connection to the host.
|
|
|
+// (The field is exported so it can be accessed via reflect from net/http; tested
|
|
|
+// by TestNoDialH2RoundTripperType)
|
|
|
+type noDialH2RoundTripper struct{ *Transport }
|
|
|
+
|
|
|
+func (rt noDialH2RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
|
+ res, err := rt.Transport.RoundTrip(req)
|
|
|
+ if isNoCachedConnError(err) {
|
|
|
+ return nil, http.ErrSkipAltProtocol
|
|
|
+ }
|
|
|
+ return res, err
|
|
|
+}
|
|
|
+
|
|
|
+func (t *Transport) idleConnTimeout() time.Duration {
|
|
|
+ if t.t1 != nil {
|
|
|
+ return t.t1.IdleConnTimeout
|
|
|
+ }
|
|
|
+ return 0
|
|
|
+}
|
|
|
+
|
|
|
+func traceGetConn(req *http.Request, hostPort string) {
|
|
|
+ trace := httptrace.ContextClientTrace(req.Context())
|
|
|
+ if trace == nil || trace.GetConn == nil {
|
|
|
+ return
|
|
|
+ }
|
|
|
+ trace.GetConn(hostPort)
|
|
|
+}
|
|
|
+
|
|
|
+func traceGotConn(req *http.Request, cc *ClientConn) {
|
|
|
+ trace := httptrace.ContextClientTrace(req.Context())
|
|
|
+ if trace == nil || trace.GotConn == nil {
|
|
|
+ return
|
|
|
+ }
|
|
|
+ ci := httptrace.GotConnInfo{Conn: cc.tconn}
|
|
|
+ cc.mu.Lock()
|
|
|
+ ci.Reused = cc.nextStreamID > 1
|
|
|
+ ci.WasIdle = len(cc.streams) == 0 && ci.Reused
|
|
|
+ if ci.WasIdle && !cc.lastActive.IsZero() {
|
|
|
+ ci.IdleTime = time.Now().Sub(cc.lastActive)
|
|
|
+ }
|
|
|
+ cc.mu.Unlock()
|
|
|
+
|
|
|
+ trace.GotConn(ci)
|
|
|
+}
|
|
|
+
|
|
|
+func traceWroteHeaders(trace *httptrace.ClientTrace) {
|
|
|
+ if trace != nil && trace.WroteHeaders != nil {
|
|
|
+ trace.WroteHeaders()
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func traceGot100Continue(trace *httptrace.ClientTrace) {
|
|
|
+ if trace != nil && trace.Got100Continue != nil {
|
|
|
+ trace.Got100Continue()
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func traceWait100Continue(trace *httptrace.ClientTrace) {
|
|
|
+ if trace != nil && trace.Wait100Continue != nil {
|
|
|
+ trace.Wait100Continue()
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func traceWroteRequest(trace *httptrace.ClientTrace, err error) {
|
|
|
+ if trace != nil && trace.WroteRequest != nil {
|
|
|
+ trace.WroteRequest(httptrace.WroteRequestInfo{Err: err})
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func traceFirstResponseByte(trace *httptrace.ClientTrace) {
|
|
|
+ if trace != nil && trace.GotFirstResponseByte != nil {
|
|
|
+ trace.GotFirstResponseByte()
|
|
|
+ }
|
|
|
+}
|