diff options
Diffstat (limited to 'libgo/go/net/http/transport_test.go')
-rw-r--r-- | libgo/go/net/http/transport_test.go | 407 |
1 files changed, 398 insertions, 9 deletions
diff --git a/libgo/go/net/http/transport_test.go b/libgo/go/net/http/transport_test.go index 6e075847dde..2b58e1daecb 100644 --- a/libgo/go/net/http/transport_test.go +++ b/libgo/go/net/http/transport_test.go @@ -20,6 +20,7 @@ import ( "encoding/binary" "errors" "fmt" + "go/token" "internal/nettrace" "io" "io/ioutil" @@ -42,7 +43,7 @@ import ( "testing" "time" - "internal/x/net/http/httpguts" + "golang.org/x/net/http/httpguts" ) // TODO: test 5 pipelined requests with responses: 1) OK, 2) OK, Connection: Close @@ -588,6 +589,106 @@ func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) { <-reqComplete } +func TestTransportMaxConnsPerHost(t *testing.T) { + defer afterTest(t) + + h := HandlerFunc(func(w ResponseWriter, r *Request) { + _, err := w.Write([]byte("foo")) + if err != nil { + t.Fatalf("Write: %v", err) + } + }) + + testMaxConns := func(scheme string, ts *httptest.Server) { + defer ts.Close() + + c := ts.Client() + tr := c.Transport.(*Transport) + tr.MaxConnsPerHost = 1 + if err := ExportHttp2ConfigureTransport(tr); err != nil { + t.Fatalf("ExportHttp2ConfigureTransport: %v", err) + } + + connCh := make(chan net.Conn, 1) + var dialCnt, gotConnCnt, tlsHandshakeCnt int32 + tr.Dial = func(network, addr string) (net.Conn, error) { + atomic.AddInt32(&dialCnt, 1) + c, err := net.Dial(network, addr) + connCh <- c + return c, err + } + + doReq := func() { + trace := &httptrace.ClientTrace{ + GotConn: func(connInfo httptrace.GotConnInfo) { + if !connInfo.Reused { + atomic.AddInt32(&gotConnCnt, 1) + } + }, + TLSHandshakeStart: func() { + atomic.AddInt32(&tlsHandshakeCnt, 1) + }, + } + req, _ := NewRequest("GET", ts.URL, nil) + req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) + + resp, err := c.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read body failed: %v", err) + } + } + + wg := sync.WaitGroup{} + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + doReq() + }() + } + wg.Wait() + + expected := int32(tr.MaxConnsPerHost) + if dialCnt != expected { + t.Errorf("Too many dials (%s): %d", scheme, dialCnt) + } + if gotConnCnt != expected { + t.Errorf("Too many get connections (%s): %d", scheme, gotConnCnt) + } + if ts.TLS != nil && tlsHandshakeCnt != expected { + t.Errorf("Too many tls handshakes (%s): %d", scheme, tlsHandshakeCnt) + } + + (<-connCh).Close() + tr.CloseIdleConnections() + + doReq() + expected++ + if dialCnt != expected { + t.Errorf("Too many dials (%s): %d", scheme, dialCnt) + } + if gotConnCnt != expected { + t.Errorf("Too many get connections (%s): %d", scheme, gotConnCnt) + } + if ts.TLS != nil && tlsHandshakeCnt != expected { + t.Errorf("Too many tls handshakes (%s): %d", scheme, tlsHandshakeCnt) + } + } + + testMaxConns("http", httptest.NewServer(h)) + testMaxConns("https", httptest.NewTLSServer(h)) + + ts := httptest.NewUnstartedServer(h) + ts.TLS = &tls.Config{NextProtos: []string{"h2"}} + ts.StartTLS() + testMaxConns("http2", ts) +} + func TestTransportRemovesDeadIdleConnections(t *testing.T) { setParallel(t) defer afterTest(t) @@ -636,6 +737,8 @@ func TestTransportRemovesDeadIdleConnections(t *testing.T) { } } +// Test that the Transport notices when a server hangs up on its +// unexpectedly (a keep-alive connection is closed). func TestTransportServerClosingUnexpectedly(t *testing.T) { setParallel(t) defer afterTest(t) @@ -672,13 +775,14 @@ func TestTransportServerClosingUnexpectedly(t *testing.T) { body1 := fetch(1, 0) body2 := fetch(2, 0) - ts.CloseClientConnections() // surprise! - - // This test has an expected race. Sleeping for 25 ms prevents - // it on most fast machines, causing the next fetch() call to - // succeed quickly. But if we do get errors, fetch() will retry 5 - // times with some delays between. - time.Sleep(25 * time.Millisecond) + // Close all the idle connections in a way that's similar to + // the server hanging up on us. We don't use + // httptest.Server.CloseClientConnections because it's + // best-effort and stops blocking after 5 seconds. On a loaded + // machine running many tests concurrently it's possible for + // that method to be async and cause the body3 fetch below to + // run on an old connection. This function is synchronous. + ExportCloseTransportConnsAbruptly(c.Transport.(*Transport)) body3 := fetch(3, 5) @@ -865,6 +969,10 @@ func TestRoundTripGzip(t *testing.T) { req.Header.Set("Accept-Encoding", test.accept) } res, err := tr.RoundTrip(req) + if err != nil { + t.Errorf("%d. RoundTrip: %v", i, err) + continue + } var body []byte if test.compressed { var r *gzip.Reader @@ -2110,7 +2218,7 @@ func testCancelRequestWithChannelBeforeDo(t *testing.T, withCtx bool) { } } else { if err == nil || !strings.Contains(err.Error(), "canceled") { - t.Errorf("Do error = %v; want cancelation", err) + t.Errorf("Do error = %v; want cancellation", err) } } } @@ -3589,6 +3697,13 @@ func TestTransportAutomaticHTTP2(t *testing.T) { testTransportAutoHTTP(t, &Transport{}, true) } +func TestTransportAutomaticHTTP2_DialerAndTLSConfigSupportsHTTP2AndTLSConfig(t *testing.T) { + testTransportAutoHTTP(t, &Transport{ + ForceAttemptHTTP2: true, + TLSClientConfig: new(tls.Config), + }, true) +} + // golang.org/issue/14391: also check DefaultTransport func TestTransportAutomaticHTTP2_DefaultTransport(t *testing.T) { testTransportAutoHTTP(t, DefaultTransport.(*Transport), true) @@ -3619,6 +3734,13 @@ func TestTransportAutomaticHTTP2_Dial(t *testing.T) { }, false) } +func TestTransportAutomaticHTTP2_DialContext(t *testing.T) { + var d net.Dialer + testTransportAutoHTTP(t, &Transport{ + DialContext: d.DialContext, + }, false) +} + func TestTransportAutomaticHTTP2_DialTLS(t *testing.T) { testTransportAutoHTTP(t, &Transport{ DialTLS: func(network, addr string) (net.Conn, error) { @@ -5059,3 +5181,270 @@ func TestTransportRequestReplayable(t *testing.T) { }) } } + +// testMockTCPConn is a mock TCP connection used to test that +// ReadFrom is called when sending the request body. +type testMockTCPConn struct { + *net.TCPConn + + ReadFromCalled bool +} + +func (c *testMockTCPConn) ReadFrom(r io.Reader) (int64, error) { + c.ReadFromCalled = true + return c.TCPConn.ReadFrom(r) +} + +func TestTransportRequestWriteRoundTrip(t *testing.T) { + nBytes := int64(1 << 10) + newFileFunc := func() (r io.Reader, done func(), err error) { + f, err := ioutil.TempFile("", "net-http-newfilefunc") + if err != nil { + return nil, nil, err + } + + // Write some bytes to the file to enable reading. + if _, err := io.CopyN(f, rand.Reader, nBytes); err != nil { + return nil, nil, fmt.Errorf("failed to write data to file: %v", err) + } + if _, err := f.Seek(0, 0); err != nil { + return nil, nil, fmt.Errorf("failed to seek to front: %v", err) + } + + done = func() { + f.Close() + os.Remove(f.Name()) + } + + return f, done, nil + } + + newBufferFunc := func() (io.Reader, func(), error) { + return bytes.NewBuffer(make([]byte, nBytes)), func() {}, nil + } + + cases := []struct { + name string + readerFunc func() (io.Reader, func(), error) + contentLength int64 + expectedReadFrom bool + }{ + { + name: "file, length", + readerFunc: newFileFunc, + contentLength: nBytes, + expectedReadFrom: true, + }, + { + name: "file, no length", + readerFunc: newFileFunc, + }, + { + name: "file, negative length", + readerFunc: newFileFunc, + contentLength: -1, + }, + { + name: "buffer", + contentLength: nBytes, + readerFunc: newBufferFunc, + }, + { + name: "buffer, no length", + readerFunc: newBufferFunc, + }, + { + name: "buffer, length -1", + contentLength: -1, + readerFunc: newBufferFunc, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + r, cleanup, err := tc.readerFunc() + if err != nil { + t.Fatal(err) + } + defer cleanup() + + tConn := &testMockTCPConn{} + trFunc := func(tr *Transport) { + tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + var d net.Dialer + conn, err := d.DialContext(ctx, network, addr) + if err != nil { + return nil, err + } + + tcpConn, ok := conn.(*net.TCPConn) + if !ok { + return nil, fmt.Errorf("%s/%s does not provide a *net.TCPConn", network, addr) + } + + tConn.TCPConn = tcpConn + return tConn, nil + } + } + + cst := newClientServerTest( + t, + h1Mode, + HandlerFunc(func(w ResponseWriter, r *Request) { + io.Copy(ioutil.Discard, r.Body) + r.Body.Close() + w.WriteHeader(200) + }), + trFunc, + ) + defer cst.close() + + req, err := NewRequest("PUT", cst.ts.URL, r) + if err != nil { + t.Fatal(err) + } + req.ContentLength = tc.contentLength + req.Header.Set("Content-Type", "application/octet-stream") + resp, err := cst.c.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + t.Fatalf("status code = %d; want 200", resp.StatusCode) + } + + if !tConn.ReadFromCalled && tc.expectedReadFrom { + t.Fatalf("did not call ReadFrom") + } + + if tConn.ReadFromCalled && !tc.expectedReadFrom { + t.Fatalf("ReadFrom was unexpectedly invoked") + } + }) + } +} + +func TestTransportClone(t *testing.T) { + tr := &Transport{ + Proxy: func(*Request) (*url.URL, error) { panic("") }, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") }, + Dial: func(network, addr string) (net.Conn, error) { panic("") }, + DialTLS: func(network, addr string) (net.Conn, error) { panic("") }, + TLSClientConfig: new(tls.Config), + TLSHandshakeTimeout: time.Second, + DisableKeepAlives: true, + DisableCompression: true, + MaxIdleConns: 1, + MaxIdleConnsPerHost: 1, + MaxConnsPerHost: 1, + IdleConnTimeout: time.Second, + ResponseHeaderTimeout: time.Second, + ExpectContinueTimeout: time.Second, + ProxyConnectHeader: Header{}, + MaxResponseHeaderBytes: 1, + ForceAttemptHTTP2: true, + TLSNextProto: map[string]func(authority string, c *tls.Conn) RoundTripper{ + "foo": func(authority string, c *tls.Conn) RoundTripper { panic("") }, + }, + ReadBufferSize: 1, + WriteBufferSize: 1, + } + tr2 := tr.Clone() + rv := reflect.ValueOf(tr2).Elem() + rt := rv.Type() + for i := 0; i < rt.NumField(); i++ { + sf := rt.Field(i) + if !token.IsExported(sf.Name) { + continue + } + if rv.Field(i).IsZero() { + t.Errorf("cloned field t2.%s is zero", sf.Name) + } + } + + if _, ok := tr2.TLSNextProto["foo"]; !ok { + t.Errorf("cloned Transport lacked TLSNextProto 'foo' key") + } + + // But test that a nil TLSNextProto is kept nil: + tr = new(Transport) + tr2 = tr.Clone() + if tr2.TLSNextProto != nil { + t.Errorf("Transport.TLSNextProto unexpected non-nil") + } +} + +func TestIs408(t *testing.T) { + tests := []struct { + in string + want bool + }{ + {"HTTP/1.0 408", true}, + {"HTTP/1.1 408", true}, + {"HTTP/1.8 408", true}, + {"HTTP/2.0 408", false}, // maybe h2c would do this? but false for now. + {"HTTP/1.1 408 ", true}, + {"HTTP/1.1 40", false}, + {"http/1.0 408", false}, + {"HTTP/1-1 408", false}, + } + for _, tt := range tests { + if got := Export_is408Message([]byte(tt.in)); got != tt.want { + t.Errorf("is408Message(%q) = %v; want %v", tt.in, got, tt.want) + } + } +} + +func TestTransportIgnores408(t *testing.T) { + // Not parallel. Relies on mutating the log package's global Output. + defer log.SetOutput(log.Writer()) + + var logout bytes.Buffer + log.SetOutput(&logout) + + defer afterTest(t) + const target = "backend:443" + + cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + nc, _, err := w.(Hijacker).Hijack() + if err != nil { + t.Error(err) + return + } + defer nc.Close() + nc.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok")) + nc.Write([]byte("HTTP/1.1 408 bye\r\n")) // changing 408 to 409 makes test fail + })) + defer cst.close() + req, err := NewRequest("GET", cst.ts.URL, nil) + if err != nil { + t.Fatal(err) + } + res, err := cst.c.Do(req) + if err != nil { + t.Fatal(err) + } + slurp, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if err != nil { + t.Fatal(err) + } + if string(slurp) != "ok" { + t.Fatalf("got %q; want ok", slurp) + } + + t0 := time.Now() + for i := 0; i < 50; i++ { + time.Sleep(time.Duration(i) * 5 * time.Millisecond) + if cst.tr.IdleConnKeyCountForTesting() == 0 { + if got := logout.String(); got != "" { + t.Fatalf("expected no log output; got: %s", got) + } + return + } + } + t.Fatalf("timeout after %v waiting for Transport connections to die off", time.Since(t0)) +} |