diff options
Diffstat (limited to 'libgo/go/crypto/tls/tls_test.go')
-rw-r--r-- | libgo/go/crypto/tls/tls_test.go | 150 |
1 files changed, 135 insertions, 15 deletions
diff --git a/libgo/go/crypto/tls/tls_test.go b/libgo/go/crypto/tls/tls_test.go index 48b46a003a6..8933f4f2015 100644 --- a/libgo/go/crypto/tls/tls_test.go +++ b/libgo/go/crypto/tls/tls_test.go @@ -11,6 +11,7 @@ import ( "fmt" "internal/testenv" "io" + "io/ioutil" "math" "math/rand" "net" @@ -98,6 +99,7 @@ var keyPairTests = []struct { } func TestX509KeyPair(t *testing.T) { + t.Parallel() var pem []byte for _, test := range keyPairTests { pem = []byte(test.cert + test.key) @@ -241,7 +243,7 @@ func testConnReadNonzeroAndEOF(t *testing.T, delay time.Duration) error { srvCh <- nil return } - serverConfig := testConfig.clone() + serverConfig := testConfig.Clone() srv := Server(sconn, serverConfig) if err := srv.Handshake(); err != nil { serr = fmt.Errorf("handshake: %v", err) @@ -251,7 +253,7 @@ func testConnReadNonzeroAndEOF(t *testing.T, delay time.Duration) error { srvCh <- srv }() - clientConfig := testConfig.clone() + clientConfig := testConfig.Clone() conn, err := Dial("tcp", ln.Addr().String(), clientConfig) if err != nil { t.Fatal(err) @@ -293,18 +295,20 @@ func TestTLSUniqueMatches(t *testing.T) { for i := 0; i < 2; i++ { sconn, err := ln.Accept() if err != nil { - t.Fatal(err) + t.Error(err) + return } - serverConfig := testConfig.clone() + serverConfig := testConfig.Clone() srv := Server(sconn, serverConfig) if err := srv.Handshake(); err != nil { - t.Fatal(err) + t.Error(err) + return } serverTLSUniques <- srv.ConnectionState().TLSUnique } }() - clientConfig := testConfig.clone() + clientConfig := testConfig.Clone() clientConfig.ClientSessionCache = NewLRUClientSessionCache(1) conn, err := Dial("tcp", ln.Addr().String(), clientConfig) if err != nil { @@ -394,7 +398,7 @@ func TestConnCloseBreakingWrite(t *testing.T) { srvCh <- nil return } - serverConfig := testConfig.clone() + serverConfig := testConfig.Clone() srv := Server(sconn, serverConfig) if err := srv.Handshake(); err != nil { serr = fmt.Errorf("handshake: %v", err) @@ -414,7 +418,7 @@ func TestConnCloseBreakingWrite(t *testing.T) { Conn: cconn, } - clientConfig := testConfig.clone() + clientConfig := testConfig.Clone() tconn := Client(conn, clientConfig) if err := tconn.Handshake(); err != nil { t.Fatal(err) @@ -458,6 +462,112 @@ func TestConnCloseBreakingWrite(t *testing.T) { } } +func TestConnCloseWrite(t *testing.T) { + ln := newLocalListener(t) + defer ln.Close() + + clientDoneChan := make(chan struct{}) + + serverCloseWrite := func() error { + sconn, err := ln.Accept() + if err != nil { + return fmt.Errorf("accept: %v", err) + } + defer sconn.Close() + + serverConfig := testConfig.Clone() + srv := Server(sconn, serverConfig) + if err := srv.Handshake(); err != nil { + return fmt.Errorf("handshake: %v", err) + } + defer srv.Close() + + data, err := ioutil.ReadAll(srv) + if err != nil { + return err + } + if len(data) > 0 { + return fmt.Errorf("Read data = %q; want nothing", data) + } + + if err := srv.CloseWrite(); err != nil { + return fmt.Errorf("server CloseWrite: %v", err) + } + + // Wait for clientCloseWrite to finish, so we know we + // tested the CloseWrite before we defer the + // sconn.Close above, which would also cause the + // client to unblock like CloseWrite. + <-clientDoneChan + return nil + } + + clientCloseWrite := func() error { + defer close(clientDoneChan) + + clientConfig := testConfig.Clone() + conn, err := Dial("tcp", ln.Addr().String(), clientConfig) + if err != nil { + return err + } + if err := conn.Handshake(); err != nil { + return err + } + defer conn.Close() + + if err := conn.CloseWrite(); err != nil { + return fmt.Errorf("client CloseWrite: %v", err) + } + + if _, err := conn.Write([]byte{0}); err != errShutdown { + return fmt.Errorf("CloseWrite error = %v; want errShutdown", err) + } + + data, err := ioutil.ReadAll(conn) + if err != nil { + return err + } + if len(data) > 0 { + return fmt.Errorf("Read data = %q; want nothing", data) + } + return nil + } + + errChan := make(chan error, 2) + + go func() { errChan <- serverCloseWrite() }() + go func() { errChan <- clientCloseWrite() }() + + for i := 0; i < 2; i++ { + select { + case err := <-errChan: + if err != nil { + t.Fatal(err) + } + case <-time.After(10 * time.Second): + t.Fatal("deadlock") + } + } + + // Also test CloseWrite being called before the handshake is + // finished: + { + ln2 := newLocalListener(t) + defer ln2.Close() + + netConn, err := net.Dial("tcp", ln2.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer netConn.Close() + conn := Client(netConn, testConfig.Clone()) + + if err := conn.CloseWrite(); err != errEarlyCloseWrite { + t.Errorf("CloseWrite error = %v; want errEarlyCloseWrite", err) + } + } +} + func TestClone(t *testing.T) { var c1 Config v := reflect.ValueOf(&c1).Elem() @@ -477,12 +587,12 @@ func TestClone(t *testing.T) { case "Rand": f.Set(reflect.ValueOf(io.Reader(os.Stdin))) continue - case "Time", "GetCertificate": + case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate", "GetClientCertificate": // DeepEqual can't compare functions. continue case "Certificates": f.Set(reflect.ValueOf([]Certificate{ - {Certificate: [][]byte{[]byte{'b'}}}, + {Certificate: [][]byte{{'b'}}}, })) continue case "NameToCertificate": @@ -494,6 +604,10 @@ func TestClone(t *testing.T) { case "ClientSessionCache": f.Set(reflect.ValueOf(NewLRUClientSessionCache(10))) continue + case "KeyLogWriter": + f.Set(reflect.ValueOf(io.Writer(os.Stdout))) + continue + } q, ok := quick.Value(f.Type(), rnd) @@ -503,7 +617,11 @@ func TestClone(t *testing.T) { f.Set(q) } - c2 := c1.clone() + c2 := c1.Clone() + // DeepEqual also compares unexported fields, thus c2 needs to have run + // serverInit in order to be DeepEqual to c1. Cloning it and discarding + // the result is sufficient. + c2.Clone() if !reflect.DeepEqual(&c1, c2) { t.Errorf("clone failed to copy a field") @@ -551,7 +669,8 @@ func throughput(b *testing.B, totalBytes int64, dynamicRecordSizingDisabled bool // (cannot call b.Fatal in goroutine) panic(fmt.Errorf("accept: %v", err)) } - serverConfig := testConfig.clone() + serverConfig := testConfig.Clone() + serverConfig.CipherSuites = nil // the defaults may prefer faster ciphers serverConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled srv := Server(sconn, serverConfig) if err := srv.Handshake(); err != nil { @@ -564,7 +683,8 @@ func throughput(b *testing.B, totalBytes int64, dynamicRecordSizingDisabled bool }() b.SetBytes(totalBytes) - clientConfig := testConfig.clone() + clientConfig := testConfig.Clone() + clientConfig.CipherSuites = nil // the defaults may prefer faster ciphers clientConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled buf := make([]byte, bufsize) @@ -641,7 +761,7 @@ func latency(b *testing.B, bps int, dynamicRecordSizingDisabled bool) { // (cannot call b.Fatal in goroutine) panic(fmt.Errorf("accept: %v", err)) } - serverConfig := testConfig.clone() + serverConfig := testConfig.Clone() serverConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled srv := Server(&slowConn{sconn, bps}, serverConfig) if err := srv.Handshake(); err != nil { @@ -651,7 +771,7 @@ func latency(b *testing.B, bps int, dynamicRecordSizingDisabled bool) { } }() - clientConfig := testConfig.clone() + clientConfig := testConfig.Clone() clientConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled buf := make([]byte, 16384) |