diff options
Diffstat (limited to 'libgo/go/net/http/transport.go')
-rw-r--r-- | libgo/go/net/http/transport.go | 199 |
1 files changed, 127 insertions, 72 deletions
diff --git a/libgo/go/net/http/transport.go b/libgo/go/net/http/transport.go index ee279877e02..64d8510b959 100644 --- a/libgo/go/net/http/transport.go +++ b/libgo/go/net/http/transport.go @@ -89,7 +89,7 @@ const DefaultMaxIdleConnsPerHost = 2 // Request.GetBody defined. HTTP requests are considered idempotent if // they have HTTP methods GET, HEAD, OPTIONS, or TRACE; or if their // Header map contains an "Idempotency-Key" or "X-Idempotency-Key" -// entry. If the idempotency key value is an zero-length slice, the +// entry. If the idempotency key value is a zero-length slice, the // request is treated as idempotent but the header is not sent on the // wire. type Transport struct { @@ -142,15 +142,24 @@ type Transport struct { // If both are set, DialContext takes priority. Dial func(network, addr string) (net.Conn, error) - // DialTLS specifies an optional dial function for creating + // DialTLSContext specifies an optional dial function for creating // TLS connections for non-proxied HTTPS requests. // - // If DialTLS is nil, Dial and TLSClientConfig are used. + // If DialTLSContext is nil (and the deprecated DialTLS below is also nil), + // DialContext and TLSClientConfig are used. // - // If DialTLS is set, the Dial hook is not used for HTTPS + // If DialTLSContext is set, the Dial and DialContext hooks are not used for HTTPS // requests and the TLSClientConfig and TLSHandshakeTimeout // are ignored. The returned net.Conn is assumed to already be // past the TLS handshake. + DialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error) + + // DialTLS specifies an optional dial function for creating + // TLS connections for non-proxied HTTPS requests. + // + // Deprecated: Use DialTLSContext instead, which allows the transport + // to cancel dials as soon as they are no longer needed. + // If both are set, DialTLSContext takes priority. DialTLS func(network, addr string) (net.Conn, error) // TLSClientConfig specifies the TLS configuration to use with @@ -218,7 +227,7 @@ type Transport struct { ExpectContinueTimeout time.Duration // TLSNextProto specifies how the Transport switches to an - // alternate protocol (such as HTTP/2) after a TLS NPN/ALPN + // alternate protocol (such as HTTP/2) after a TLS ALPN // protocol negotiation. If Transport dials an TLS connection // with a non-empty protocol name and TLSNextProto contains a // map entry for that key (such as "h2"), then the func is @@ -286,7 +295,7 @@ func (t *Transport) Clone() *Transport { DialContext: t.DialContext, Dial: t.Dial, DialTLS: t.DialTLS, - TLSClientConfig: t.TLSClientConfig.Clone(), + DialTLSContext: t.DialTLSContext, TLSHandshakeTimeout: t.TLSHandshakeTimeout, DisableKeepAlives: t.DisableKeepAlives, DisableCompression: t.DisableCompression, @@ -302,6 +311,9 @@ func (t *Transport) Clone() *Transport { WriteBufferSize: t.WriteBufferSize, ReadBufferSize: t.ReadBufferSize, } + if t.TLSClientConfig != nil { + t2.TLSClientConfig = t.TLSClientConfig.Clone() + } if !t.tlsNextProtoWasNil { npm := map[string]func(authority string, c *tls.Conn) RoundTripper{} for k, v := range t.TLSNextProto { @@ -322,6 +334,10 @@ type h2Transport interface { CloseIdleConnections() } +func (t *Transport) hasCustomTLSDialer() bool { + return t.DialTLS != nil || t.DialTLSContext != nil +} + // onceSetNextProtoDefaults initializes TLSNextProto. // It must be called via t.nextProtoOnce.Do. func (t *Transport) onceSetNextProtoDefaults() { @@ -350,7 +366,7 @@ func (t *Transport) onceSetNextProtoDefaults() { // Transport. return } - if !t.ForceAttemptHTTP2 && (t.TLSClientConfig != nil || t.Dial != nil || t.DialTLS != nil || t.DialContext != nil) { + if !t.ForceAttemptHTTP2 && (t.TLSClientConfig != nil || t.Dial != nil || t.DialContext != nil || t.hasCustomTLSDialer()) { // Be conservative and don't automatically enable // http2 if they've specified a custom TLS config or // custom dialers. Let them opt-in themselves via @@ -359,6 +375,9 @@ func (t *Transport) onceSetNextProtoDefaults() { // However, if ForceAttemptHTTP2 is true, it overrides the above checks. return } + if omitBundledHTTP2 { + return + } t2, err := http2configureTransport(t) if err != nil { log.Printf("Error enabling Transport HTTP/2 support: %v", err) @@ -437,7 +456,7 @@ func (tr *transportRequest) setError(err error) { tr.mu.Unlock() } -// useRegisteredProtocol reports whether an alternate protocol (as reqistered +// useRegisteredProtocol reports whether an alternate protocol (as registered // with Transport.RegisterProtocol) should be respected for this request. func (t *Transport) useRegisteredProtocol(req *Request) bool { if req.URL.Scheme == "https" && req.requiresHTTP1() { @@ -469,10 +488,12 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) { if isHTTP { for k, vv := range req.Header { if !httpguts.ValidHeaderFieldName(k) { + req.closeBody() return nil, fmt.Errorf("net/http: invalid header field name %q", k) } for _, v := range vv { if !httpguts.ValidHeaderFieldValue(v) { + req.closeBody() return nil, fmt.Errorf("net/http: invalid header field value %q for key %v", v, k) } } @@ -492,6 +513,7 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) { return nil, &badStringError{"unsupported protocol scheme", scheme} } if req.Method != "" && !validMethod(req.Method) { + req.closeBody() return nil, fmt.Errorf("net/http: invalid method %q", req.Method) } if req.URL.Host == "" { @@ -537,9 +559,16 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) { if err == nil { return resp, nil } - if http2isNoCachedConnError(err) { - t.removeIdleConn(pconn) - } else if !pconn.shouldRetryRequest(req, err) { + + // Failed. Clean up and determine whether to retry. + + _, isH2DialError := pconn.alt.(http2erringRoundTripper) + if http2isNoCachedConnError(err) || isH2DialError { + if t.removeIdleConn(pconn) { + t.decConnsPerHost(pconn.cacheKey) + } + } + if !pconn.shouldRetryRequest(req, err) { // Issue 16465: return underlying net.Conn.Read error from peek, // as we've historically done. if e, ok := err.(transportReadFromServerError); ok { @@ -710,20 +739,10 @@ func resetProxyConfig() { } func (t *Transport) connectMethodForRequest(treq *transportRequest) (cm connectMethod, err error) { - // TODO: the validPort check is redundant after CL 189258, as url.URL.Port - // only returns valid ports now. golang.org/issue/33600 - if port := treq.URL.Port(); !validPort(port) { - return cm, fmt.Errorf("invalid URL port %q", port) - } cm.targetScheme = treq.URL.Scheme cm.targetAddr = canonicalAddr(treq.URL) if t.Proxy != nil { cm.proxyURL, err = t.Proxy(treq.Request) - if err == nil && cm.proxyURL != nil { - if port := cm.proxyURL.Port(); !validPort(port) { - return cm, fmt.Errorf("invalid proxy URL port %q", port) - } - } } cm.onlyH1 = treq.requiresHTTP1() return cm, err @@ -753,7 +772,6 @@ var ( errCloseIdleConns = errors.New("http: CloseIdleConnections called") errReadLoopExiting = errors.New("http: persistConn.readLoop exiting") errIdleConnTimeout = errors.New("http: idle connection timeout") - errNotCachingH2Conn = errors.New("http: not caching alternate protocol's connections") // errServerClosedIdle is not seen by users for idempotent requests, but may be // seen by a user if the server shuts down an idle connection and sends its FIN @@ -911,16 +929,37 @@ func (t *Transport) queueForIdleConn(w *wantConn) (delivered bool) { return false } + // If IdleConnTimeout is set, calculate the oldest + // persistConn.idleAt time we're willing to use a cached idle + // conn. + var oldTime time.Time + if t.IdleConnTimeout > 0 { + oldTime = time.Now().Add(-t.IdleConnTimeout) + } + // Look for most recently-used idle connection. if list, ok := t.idleConn[w.key]; ok { stop := false delivered := false for len(list) > 0 && !stop { pconn := list[len(list)-1] - if pconn.isBroken() { - // persistConn.readLoop has marked the connection broken, - // but Transport.removeIdleConn has not yet removed it from the idle list. - // Drop on floor on behalf of Transport.removeIdleConn. + + // See whether this connection has been idle too long, considering + // only the wall time (the Round(0)), in case this is a laptop or VM + // coming out of suspend with previously cached idle connections. + tooOld := !oldTime.IsZero() && pconn.idleAt.Round(0).Before(oldTime) + if tooOld { + // Async cleanup. Launch in its own goroutine (as if a + // time.AfterFunc called it); it acquires idleMu, which we're + // holding, and does a synchronous net.Conn.Close. + go pconn.closeConnIfStillIdle() + } + if pconn.isBroken() || tooOld { + // If either persistConn.readLoop has marked the connection + // broken, but Transport.removeIdleConn has not yet removed it + // from the idle list, or if this persistConn is too old (it was + // idle too long), then ignore it and look for another. In both + // cases it's already in the process of being closed. list = list[:len(list)-1] continue } @@ -960,26 +999,28 @@ func (t *Transport) queueForIdleConn(w *wantConn) (delivered bool) { } // removeIdleConn marks pconn as dead. -func (t *Transport) removeIdleConn(pconn *persistConn) { +func (t *Transport) removeIdleConn(pconn *persistConn) bool { t.idleMu.Lock() defer t.idleMu.Unlock() - t.removeIdleConnLocked(pconn) + return t.removeIdleConnLocked(pconn) } // t.idleMu must be held. -func (t *Transport) removeIdleConnLocked(pconn *persistConn) { +func (t *Transport) removeIdleConnLocked(pconn *persistConn) bool { if pconn.idleTimer != nil { pconn.idleTimer.Stop() } t.idleLRU.remove(pconn) key := pconn.cacheKey pconns := t.idleConn[key] + var removed bool switch len(pconns) { case 0: // Nothing case 1: if pconns[0] == pconn { delete(t.idleConn, key) + removed = true } default: for i, v := range pconns { @@ -990,9 +1031,11 @@ func (t *Transport) removeIdleConnLocked(pconn *persistConn) { // conns at the end. copy(pconns[i:], pconns[i+1:]) t.idleConn[key] = pconns[:len(pconns)-1] + removed = true break } } + return removed } func (t *Transport) setReqCanceler(r *Request, fn func(error)) { @@ -1177,6 +1220,18 @@ func (q *wantConnQueue) cleanFront() (cleaned bool) { } } +func (t *Transport) customDialTLS(ctx context.Context, network, addr string) (conn net.Conn, err error) { + if t.DialTLSContext != nil { + conn, err = t.DialTLSContext(ctx, network, addr) + } else { + conn, err = t.DialTLS(network, addr) + } + if conn == nil && err == nil { + err = errors.New("net/http: Transport.DialTLS or DialTLSContext returned (nil, nil)") + } + return +} + // getConn dials and creates a new persistConn to the target as // specified in the connectMethod. This includes doing a proxy CONNECT // and/or setting up TLS. If this doesn't return an error, the persistConn @@ -1206,7 +1261,9 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi // Queue for idle connection. if delivered := t.queueForIdleConn(w); delivered { pc := w.pc - if trace != nil && trace.GotConn != nil { + // Trace only for HTTP/1. + // HTTP/2 calls trace.GotConn itself. + if pc.alt == nil && trace != nil && trace.GotConn != nil { trace.GotConn(pc.gotIdleConnTrace(pc.idleAt)) } // set request canceler to some non-nil function so we @@ -1360,19 +1417,6 @@ func (t *Transport) decConnsPerHost(key connectMethodKey) { } } -// The connect method and the transport can both specify a TLS -// Host name. The transport's name takes precedence if present. -func chooseTLSHost(cm connectMethod, t *Transport) string { - tlsHost := "" - if t.TLSClientConfig != nil { - tlsHost = t.TLSClientConfig.ServerName - } - if tlsHost == "" { - tlsHost = cm.tlsHost() - } - return tlsHost -} - // Add TLS to a persistent connection, i.e. negotiate a TLS session. If pconn is already a TLS // tunnel, this function establishes a nested TLS session inside the encrypted channel. // The remote endpoint's name may be overridden by TLSClientConfig.ServerName. @@ -1438,15 +1482,12 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers } return err } - if cm.scheme() == "https" && t.DialTLS != nil { + if cm.scheme() == "https" && t.hasCustomTLSDialer() { var err error - pconn.conn, err = t.DialTLS("tcp", cm.addr()) + pconn.conn, err = t.customDialTLS(ctx, "tcp", cm.addr()) if err != nil { return nil, wrapErr(err) } - if pconn.conn == nil { - return nil, wrapErr(errors.New("net/http: Transport.DialTLS returned (nil, nil)")) - } if tc, ok := pconn.conn.(*tls.Conn); ok { // Handshake here, in case DialTLS didn't. TLSNextProto below // depends on it for knowing the connection state. @@ -1527,13 +1568,44 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers if pa := cm.proxyAuth(); pa != "" { connectReq.Header.Set("Proxy-Authorization", pa) } - connectReq.Write(conn) - // Read response. - // Okay to use and discard buffered reader here, because - // TLS server will not speak until spoken to. - br := bufio.NewReader(conn) - resp, err := ReadResponse(br, connectReq) + // If there's no done channel (no deadline or cancellation + // from the caller possible), at least set some (long) + // timeout here. This will make sure we don't block forever + // and leak a goroutine if the connection stops replying + // after the TCP connect. + connectCtx := ctx + if ctx.Done() == nil { + newCtx, cancel := context.WithTimeout(ctx, 1*time.Minute) + defer cancel() + connectCtx = newCtx + } + + didReadResponse := make(chan struct{}) // closed after CONNECT write+read is done or fails + var ( + resp *Response + err error // write or read error + ) + // Write the CONNECT request & read the response. + go func() { + defer close(didReadResponse) + err = connectReq.Write(conn) + if err != nil { + return + } + // Okay to use and discard buffered reader here, because + // TLS server will not speak until spoken to. + br := bufio.NewReader(conn) + resp, err = ReadResponse(br, connectReq) + }() + select { + case <-connectCtx.Done(): + conn.Close() + <-didReadResponse + return nil, connectCtx.Err() + case <-didReadResponse: + // resp or err now set + } if err != nil { conn.Close() return nil, err @@ -1927,7 +1999,7 @@ func (pc *persistConn) readLoop() { } return } - pc.readLimit = maxInt64 // effictively no limit for response bodies + pc.readLimit = maxInt64 // effectively no limit for response bodies pc.mu.Lock() pc.numExpectedResponses-- @@ -2635,11 +2707,6 @@ func (gz *gzipReader) Close() error { return gz.body.Close() } -type readerAndCloser struct { - io.Reader - io.Closer -} - type tlsHandshakeTimeoutError struct{} func (tlsHandshakeTimeoutError) Timeout() bool { return true } @@ -2702,15 +2769,3 @@ func (cl *connLRU) remove(pc *persistConn) { func (cl *connLRU) len() int { return len(cl.m) } - -// validPort reports whether p (without the colon) is a valid port in -// a URL, per RFC 3986 Section 3.2.3, which says the port may be -// empty, or only contain digits. -func validPort(p string) bool { - for _, r := range []byte(p) { - if r < '0' || r > '9' { - return false - } - } - return true -} |