diff options
Diffstat (limited to 'libgo/go/net/http/httputil/reverseproxy.go')
-rw-r--r-- | libgo/go/net/http/httputil/reverseproxy.go | 151 |
1 files changed, 94 insertions, 57 deletions
diff --git a/libgo/go/net/http/httputil/reverseproxy.go b/libgo/go/net/http/httputil/reverseproxy.go index 49c120afde1..79c8fe27702 100644 --- a/libgo/go/net/http/httputil/reverseproxy.go +++ b/libgo/go/net/http/httputil/reverseproxy.go @@ -7,6 +7,7 @@ package httputil import ( + "context" "io" "log" "net" @@ -29,6 +30,8 @@ type ReverseProxy struct { // the request into a new request to be sent // using Transport. Its response is then copied // back to the original client unmodified. + // Director must not access the provided Request + // after returning. Director func(*http.Request) // The transport used to perform proxy requests. @@ -51,6 +54,11 @@ type ReverseProxy struct { // get byte slices for use by io.CopyBuffer when // copying HTTP response bodies. BufferPool BufferPool + + // ModifyResponse is an optional function that + // modifies the Response from the backend. + // If it returns an error, the proxy returns a StatusBadGateway error. + ModifyResponse func(*http.Response) error } // A BufferPool is an interface for getting and returning temporary @@ -120,76 +128,59 @@ var hopHeaders = []string{ "Upgrade", } -type requestCanceler interface { - CancelRequest(*http.Request) -} - -type runOnFirstRead struct { - io.Reader // optional; nil means empty body - - fn func() // Run before first Read, then set to nil -} - -func (c *runOnFirstRead) Read(bs []byte) (int, error) { - if c.fn != nil { - c.fn() - c.fn = nil - } - if c.Reader == nil { - return 0, io.EOF - } - return c.Reader.Read(bs) -} - func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { transport := p.Transport if transport == nil { transport = http.DefaultTransport } + ctx := req.Context() + if cn, ok := rw.(http.CloseNotifier); ok { + var cancel context.CancelFunc + ctx, cancel = context.WithCancel(ctx) + defer cancel() + notifyChan := cn.CloseNotify() + go func() { + select { + case <-notifyChan: + cancel() + case <-ctx.Done(): + } + }() + } + outreq := new(http.Request) *outreq = *req // includes shallow copies of maps, but okay - - if closeNotifier, ok := rw.(http.CloseNotifier); ok { - if requestCanceler, ok := transport.(requestCanceler); ok { - reqDone := make(chan struct{}) - defer close(reqDone) - - clientGone := closeNotifier.CloseNotify() - - outreq.Body = struct { - io.Reader - io.Closer - }{ - Reader: &runOnFirstRead{ - Reader: outreq.Body, - fn: func() { - go func() { - select { - case <-clientGone: - requestCanceler.CancelRequest(outreq) - case <-reqDone: - } - }() - }, - }, - Closer: outreq.Body, - } - } + if req.ContentLength == 0 { + outreq.Body = nil // Issue 16036: nil Body for http.Transport retries } + outreq = outreq.WithContext(ctx) p.Director(outreq) - outreq.Proto = "HTTP/1.1" - outreq.ProtoMajor = 1 - outreq.ProtoMinor = 1 outreq.Close = false - // Remove hop-by-hop headers to the backend. Especially - // important is "Connection" because we want a persistent - // connection, regardless of what the client sent to us. This - // is modifying the same underlying map from req (shallow + // We are modifying the same underlying map from req (shallow // copied above) so we only copy it if necessary. copiedHeaders := false + + // Remove hop-by-hop headers listed in the "Connection" header. + // See RFC 2616, section 14.10. + if c := outreq.Header.Get("Connection"); c != "" { + for _, f := range strings.Split(c, ",") { + if f = strings.TrimSpace(f); f != "" { + if !copiedHeaders { + outreq.Header = make(http.Header) + copyHeader(outreq.Header, req.Header) + copiedHeaders = true + } + outreq.Header.Del(f) + } + } + } + + // Remove hop-by-hop headers to the backend. Especially + // important is "Connection" because we want a persistent + // connection, regardless of what the client sent to us. for _, h := range hopHeaders { if outreq.Header.Get(h) != "" { if !copiedHeaders { @@ -218,16 +209,34 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { return } + // Remove hop-by-hop headers listed in the + // "Connection" header of the response. + if c := res.Header.Get("Connection"); c != "" { + for _, f := range strings.Split(c, ",") { + if f = strings.TrimSpace(f); f != "" { + res.Header.Del(f) + } + } + } + for _, h := range hopHeaders { res.Header.Del(h) } + if p.ModifyResponse != nil { + if err := p.ModifyResponse(res); err != nil { + p.logf("http: proxy error: %v", err) + rw.WriteHeader(http.StatusBadGateway) + return + } + } + copyHeader(rw.Header(), res.Header) // The "Trailer" header isn't included in the Transport's response, // at least for *http.Transport. Build it up from Trailer. if len(res.Trailer) > 0 { - var trailerKeys []string + trailerKeys := make([]string, 0, len(res.Trailer)) for k := range res.Trailer { trailerKeys = append(trailerKeys, k) } @@ -266,12 +275,40 @@ func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) { if p.BufferPool != nil { buf = p.BufferPool.Get() } - io.CopyBuffer(dst, src, buf) + p.copyBuffer(dst, src, buf) if p.BufferPool != nil { p.BufferPool.Put(buf) } } +func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) { + if len(buf) == 0 { + buf = make([]byte, 32*1024) + } + var written int64 + for { + nr, rerr := src.Read(buf) + if rerr != nil && rerr != io.EOF { + p.logf("httputil: ReverseProxy read error during body copy: %v", rerr) + } + if nr > 0 { + nw, werr := dst.Write(buf[:nr]) + if nw > 0 { + written += int64(nw) + } + if werr != nil { + return written, werr + } + if nr != nw { + return written, io.ErrShortWrite + } + } + if rerr != nil { + return written, rerr + } + } +} + func (p *ReverseProxy) logf(format string, args ...interface{}) { if p.ErrorLog != nil { p.ErrorLog.Printf(format, args...) |