summaryrefslogtreecommitdiff
path: root/libgo/go/net/http/httputil/reverseproxy.go
diff options
context:
space:
mode:
Diffstat (limited to 'libgo/go/net/http/httputil/reverseproxy.go')
-rw-r--r--libgo/go/net/http/httputil/reverseproxy.go151
1 files changed, 94 insertions, 57 deletions
diff --git a/libgo/go/net/http/httputil/reverseproxy.go b/libgo/go/net/http/httputil/reverseproxy.go
index 49c120afde1..79c8fe27702 100644
--- a/libgo/go/net/http/httputil/reverseproxy.go
+++ b/libgo/go/net/http/httputil/reverseproxy.go
@@ -7,6 +7,7 @@
package httputil
import (
+ "context"
"io"
"log"
"net"
@@ -29,6 +30,8 @@ type ReverseProxy struct {
// the request into a new request to be sent
// using Transport. Its response is then copied
// back to the original client unmodified.
+ // Director must not access the provided Request
+ // after returning.
Director func(*http.Request)
// The transport used to perform proxy requests.
@@ -51,6 +54,11 @@ type ReverseProxy struct {
// get byte slices for use by io.CopyBuffer when
// copying HTTP response bodies.
BufferPool BufferPool
+
+ // ModifyResponse is an optional function that
+ // modifies the Response from the backend.
+ // If it returns an error, the proxy returns a StatusBadGateway error.
+ ModifyResponse func(*http.Response) error
}
// A BufferPool is an interface for getting and returning temporary
@@ -120,76 +128,59 @@ var hopHeaders = []string{
"Upgrade",
}
-type requestCanceler interface {
- CancelRequest(*http.Request)
-}
-
-type runOnFirstRead struct {
- io.Reader // optional; nil means empty body
-
- fn func() // Run before first Read, then set to nil
-}
-
-func (c *runOnFirstRead) Read(bs []byte) (int, error) {
- if c.fn != nil {
- c.fn()
- c.fn = nil
- }
- if c.Reader == nil {
- return 0, io.EOF
- }
- return c.Reader.Read(bs)
-}
-
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
transport := p.Transport
if transport == nil {
transport = http.DefaultTransport
}
+ ctx := req.Context()
+ if cn, ok := rw.(http.CloseNotifier); ok {
+ var cancel context.CancelFunc
+ ctx, cancel = context.WithCancel(ctx)
+ defer cancel()
+ notifyChan := cn.CloseNotify()
+ go func() {
+ select {
+ case <-notifyChan:
+ cancel()
+ case <-ctx.Done():
+ }
+ }()
+ }
+
outreq := new(http.Request)
*outreq = *req // includes shallow copies of maps, but okay
-
- if closeNotifier, ok := rw.(http.CloseNotifier); ok {
- if requestCanceler, ok := transport.(requestCanceler); ok {
- reqDone := make(chan struct{})
- defer close(reqDone)
-
- clientGone := closeNotifier.CloseNotify()
-
- outreq.Body = struct {
- io.Reader
- io.Closer
- }{
- Reader: &runOnFirstRead{
- Reader: outreq.Body,
- fn: func() {
- go func() {
- select {
- case <-clientGone:
- requestCanceler.CancelRequest(outreq)
- case <-reqDone:
- }
- }()
- },
- },
- Closer: outreq.Body,
- }
- }
+ if req.ContentLength == 0 {
+ outreq.Body = nil // Issue 16036: nil Body for http.Transport retries
}
+ outreq = outreq.WithContext(ctx)
p.Director(outreq)
- outreq.Proto = "HTTP/1.1"
- outreq.ProtoMajor = 1
- outreq.ProtoMinor = 1
outreq.Close = false
- // Remove hop-by-hop headers to the backend. Especially
- // important is "Connection" because we want a persistent
- // connection, regardless of what the client sent to us. This
- // is modifying the same underlying map from req (shallow
+ // We are modifying the same underlying map from req (shallow
// copied above) so we only copy it if necessary.
copiedHeaders := false
+
+ // Remove hop-by-hop headers listed in the "Connection" header.
+ // See RFC 2616, section 14.10.
+ if c := outreq.Header.Get("Connection"); c != "" {
+ for _, f := range strings.Split(c, ",") {
+ if f = strings.TrimSpace(f); f != "" {
+ if !copiedHeaders {
+ outreq.Header = make(http.Header)
+ copyHeader(outreq.Header, req.Header)
+ copiedHeaders = true
+ }
+ outreq.Header.Del(f)
+ }
+ }
+ }
+
+ // Remove hop-by-hop headers to the backend. Especially
+ // important is "Connection" because we want a persistent
+ // connection, regardless of what the client sent to us.
for _, h := range hopHeaders {
if outreq.Header.Get(h) != "" {
if !copiedHeaders {
@@ -218,16 +209,34 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
return
}
+ // Remove hop-by-hop headers listed in the
+ // "Connection" header of the response.
+ if c := res.Header.Get("Connection"); c != "" {
+ for _, f := range strings.Split(c, ",") {
+ if f = strings.TrimSpace(f); f != "" {
+ res.Header.Del(f)
+ }
+ }
+ }
+
for _, h := range hopHeaders {
res.Header.Del(h)
}
+ if p.ModifyResponse != nil {
+ if err := p.ModifyResponse(res); err != nil {
+ p.logf("http: proxy error: %v", err)
+ rw.WriteHeader(http.StatusBadGateway)
+ return
+ }
+ }
+
copyHeader(rw.Header(), res.Header)
// The "Trailer" header isn't included in the Transport's response,
// at least for *http.Transport. Build it up from Trailer.
if len(res.Trailer) > 0 {
- var trailerKeys []string
+ trailerKeys := make([]string, 0, len(res.Trailer))
for k := range res.Trailer {
trailerKeys = append(trailerKeys, k)
}
@@ -266,12 +275,40 @@ func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
if p.BufferPool != nil {
buf = p.BufferPool.Get()
}
- io.CopyBuffer(dst, src, buf)
+ p.copyBuffer(dst, src, buf)
if p.BufferPool != nil {
p.BufferPool.Put(buf)
}
}
+func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) {
+ if len(buf) == 0 {
+ buf = make([]byte, 32*1024)
+ }
+ var written int64
+ for {
+ nr, rerr := src.Read(buf)
+ if rerr != nil && rerr != io.EOF {
+ p.logf("httputil: ReverseProxy read error during body copy: %v", rerr)
+ }
+ if nr > 0 {
+ nw, werr := dst.Write(buf[:nr])
+ if nw > 0 {
+ written += int64(nw)
+ }
+ if werr != nil {
+ return written, werr
+ }
+ if nr != nw {
+ return written, io.ErrShortWrite
+ }
+ }
+ if rerr != nil {
+ return written, rerr
+ }
+ }
+}
+
func (p *ReverseProxy) logf(format string, args ...interface{}) {
if p.ErrorLog != nil {
p.ErrorLog.Printf(format, args...)