diff options
author | Ian Lance Taylor <iant@golang.org> | 2018-01-09 01:23:08 +0000 |
---|---|---|
committer | Ian Lance Taylor <ian@gcc.gnu.org> | 2018-01-09 01:23:08 +0000 |
commit | 1a2f01efa63036a5104f203a4789e682c0e0915d (patch) | |
tree | 373e15778dc8295354584e1f86915ae493b604ff /libgo/go/net | |
parent | 8799df67f2dab88f9fda11739c501780a85575e2 (diff) |
libgo: update to Go1.10beta1
Update the Go library to the 1.10beta1 release.
Requires a few changes to the compiler for modifications to the map
runtime code, and to handle some nowritebarrier cases in the runtime.
Reviewed-on: https://go-review.googlesource.com/86455
gotools/:
* Makefile.am (go_cmd_vet_files): New variable.
(go_cmd_buildid_files, go_cmd_test2json_files): New variables.
(s-zdefaultcc): Change from constants to functions.
(noinst_PROGRAMS): Add vet, buildid, and test2json.
(cgo$(EXEEXT)): Link against $(LIBGOTOOL).
(vet$(EXEEXT)): New target.
(buildid$(EXEEXT)): New target.
(test2json$(EXEEXT)): New target.
(install-exec-local): Install all $(noinst_PROGRAMS).
(uninstall-local): Uninstasll all $(noinst_PROGRAMS).
(check-go-tool): Depend on $(noinst_PROGRAMS). Copy down
objabi.go.
(check-runtime): Depend on $(noinst_PROGRAMS).
(check-cgo-test, check-carchive-test): Likewise.
(check-vet): New target.
(check): Depend on check-vet. Look at cmd_vet-testlog.
(.PHONY): Add check-vet.
* Makefile.in: Rebuild.
From-SVN: r256365
Diffstat (limited to 'libgo/go/net')
70 files changed, 2942 insertions, 794 deletions
diff --git a/libgo/go/net/cgo_unix.go b/libgo/go/net/cgo_unix.go index 0de3ff8bebd..5866d384820 100644 --- a/libgo/go/net/cgo_unix.go +++ b/libgo/go/net/cgo_unix.go @@ -12,7 +12,6 @@ package net #include <sys/socket.h> #include <netinet/in.h> #include <netdb.h> -#include <stdlib.h> #include <unistd.h> #include <string.h> */ diff --git a/libgo/go/net/dial_test.go b/libgo/go/net/dial_test.go index a892bf1e140..13fa9faacb5 100644 --- a/libgo/go/net/dial_test.go +++ b/libgo/go/net/dial_test.go @@ -161,6 +161,8 @@ func dialClosedPort() (actual, expected time.Duration) { // but other platforms should be instantaneous. if runtime.GOOS == "windows" { expected = 1500 * time.Millisecond + } else if runtime.GOOS == "darwin" { + expected = 150 * time.Millisecond } else { expected = 95 * time.Millisecond } diff --git a/libgo/go/net/fd_windows.go b/libgo/go/net/fd_windows.go index c2156b255e5..e5f8da156a2 100644 --- a/libgo/go/net/fd_windows.go +++ b/libgo/go/net/fd_windows.go @@ -52,7 +52,7 @@ func newFD(sysfd syscall.Handle, family, sotype int, net string) (*netFD, error) } func (fd *netFD) init() error { - errcall, err := fd.pfd.Init(fd.net) + errcall, err := fd.pfd.Init(fd.net, true) if errcall != "" { err = wrapSyscallError(errcall, err) } @@ -223,17 +223,21 @@ func (fd *netFD) accept() (*netFD, error) { return netfd, nil } +func (fd *netFD) readMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.Sockaddr, err error) { + n, oobn, flags, sa, err = fd.pfd.ReadMsg(p, oob) + runtime.KeepAlive(fd) + return n, oobn, flags, sa, wrapSyscallError("wsarecvmsg", err) +} + +func (fd *netFD) writeMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err error) { + n, oobn, err = fd.pfd.WriteMsg(p, oob, sa) + runtime.KeepAlive(fd) + return n, oobn, wrapSyscallError("wsasendmsg", err) +} + // Unimplemented functions. func (fd *netFD) dup() (*os.File, error) { // TODO: Implement this return nil, syscall.EWINDOWS } - -func (fd *netFD) readMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.Sockaddr, err error) { - return 0, 0, 0, nil, syscall.EWINDOWS -} - -func (fd *netFD) writeMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err error) { - return 0, 0, syscall.EWINDOWS -} diff --git a/libgo/go/net/hook_windows.go b/libgo/go/net/hook_windows.go index 4e64dcef517..ab8656cbbf3 100644 --- a/libgo/go/net/hook_windows.go +++ b/libgo/go/net/hook_windows.go @@ -5,6 +5,7 @@ package net import ( + "internal/syscall/windows" "syscall" "time" ) @@ -13,7 +14,8 @@ var ( testHookDialChannel = func() { time.Sleep(time.Millisecond) } // see golang.org/issue/5349 // Placeholders for socket system calls. - socketFunc func(int, int, int) (syscall.Handle, error) = syscall.Socket - connectFunc func(syscall.Handle, syscall.Sockaddr) error = syscall.Connect - listenFunc func(syscall.Handle, int) error = syscall.Listen + socketFunc func(int, int, int) (syscall.Handle, error) = syscall.Socket + wsaSocketFunc func(int32, int32, int32, *syscall.WSAProtocolInfo, uint32, uint32) (syscall.Handle, error) = windows.WSASocket + connectFunc func(syscall.Handle, syscall.Sockaddr) error = syscall.Connect + listenFunc func(syscall.Handle, int) error = syscall.Listen ) diff --git a/libgo/go/net/hosts_test.go b/libgo/go/net/hosts_test.go index 5d6c9cfe190..f850e2fccfd 100644 --- a/libgo/go/net/hosts_test.go +++ b/libgo/go/net/hosts_test.go @@ -150,7 +150,7 @@ func testStaticAddr(t *testing.T, hostsPath string, ent staticHostEntry) { func TestHostCacheModification(t *testing.T) { // Ensure that programs can't modify the internals of the host cache. - // See https://github.com/golang/go/issues/14212. + // See https://golang.org/issues/14212. defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath) testHookHostsPath = "testdata/ipv4-hosts" diff --git a/libgo/go/net/http/client.go b/libgo/go/net/http/client.go index 4c9084ae512..6f6024ed4d8 100644 --- a/libgo/go/net/http/client.go +++ b/libgo/go/net/http/client.go @@ -127,7 +127,10 @@ type RoundTripper interface { // authentication, or cookies. // // RoundTrip should not modify the request, except for - // consuming and closing the Request's Body. + // consuming and closing the Request's Body. RoundTrip may + // read fields of the request in a separate goroutine. Callers + // should not mutate the request until the Response's Body has + // been closed. // // RoundTrip must always close the body, including on errors, // but depending on the implementation may do so in a separate @@ -536,12 +539,22 @@ func (c *Client) Do(req *Request) (*Response, error) { resp.closeBody() return nil, uerr(fmt.Errorf("failed to parse Location header %q: %v", loc, err)) } + host := "" + if req.Host != "" && req.Host != req.URL.Host { + // If the caller specified a custom Host header and the + // redirect location is relative, preserve the Host header + // through the redirect. See issue #22233. + if u, _ := url.Parse(loc); u != nil && !u.IsAbs() { + host = req.Host + } + } ireq := reqs[0] req = &Request{ Method: redirectMethod, Response: resp, URL: u, Header: make(Header), + Host: host, Cancel: ireq.Cancel, ctx: ireq.ctx, } @@ -750,7 +763,7 @@ func PostForm(url string, data url.Values) (resp *Response, err error) { // with data's keys and values URL-encoded as the request body. // // The Content-Type header is set to application/x-www-form-urlencoded. -// To set other headers, use NewRequest and DefaultClient.Do. +// To set other headers, use NewRequest and Client.Do. // // When err is nil, resp always contains a non-nil resp.Body. // Caller should close resp.Body when done reading from it. @@ -843,16 +856,8 @@ func shouldCopyHeaderOnRedirect(headerKey string, initial, dest *url.URL) bool { // directly, we don't know their scope, so we assume // it's for *.domain.com. - // TODO(bradfitz): once issue 16142 is fixed, make - // this code use those URL accessors, and consider - // "http://foo.com" and "http://foo.com:80" as - // equivalent? - - // TODO(bradfitz): better hostname canonicalization, - // at least once we figure out IDNA/Punycode (issue - // 13835). - ihost := strings.ToLower(initial.Host) - dhost := strings.ToLower(dest.Host) + ihost := canonicalAddr(initial) + dhost := canonicalAddr(dest) return isDomainOrSubdomain(dhost, ihost) } // All other headers are copied: diff --git a/libgo/go/net/http/client_test.go b/libgo/go/net/http/client_test.go index b9a1c31e43a..eea3b16fb3b 100644 --- a/libgo/go/net/http/client_test.go +++ b/libgo/go/net/http/client_test.go @@ -1426,7 +1426,7 @@ func TestClientRedirectResponseWithoutRequest(t *testing.T) { c.Get("http://dummy.tld") } -// Issue 4800: copy (some) headers when Client follows a redirect +// Issue 4800: copy (some) headers when Client follows a redirect. func TestClientCopyHeadersOnRedirect(t *testing.T) { const ( ua = "some-agent/1.2" @@ -1487,6 +1487,76 @@ func TestClientCopyHeadersOnRedirect(t *testing.T) { } } +// Issue 22233: copy host when Client follows a relative redirect. +func TestClientCopyHostOnRedirect(t *testing.T) { + // Virtual hostname: should not receive any request. + virtual := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + t.Errorf("Virtual host received request %v", r.URL) + w.WriteHeader(403) + io.WriteString(w, "should not see this response") + })) + defer virtual.Close() + virtualHost := strings.TrimPrefix(virtual.URL, "http://") + t.Logf("Virtual host is %v", virtualHost) + + // Actual hostname: should not receive any request. + const wantBody = "response body" + var tsURL string + var tsHost string + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + switch r.URL.Path { + case "/": + // Relative redirect. + if r.Host != virtualHost { + t.Errorf("Serving /: Request.Host = %#v; want %#v", r.Host, virtualHost) + w.WriteHeader(404) + return + } + w.Header().Set("Location", "/hop") + w.WriteHeader(302) + case "/hop": + // Absolute redirect. + if r.Host != virtualHost { + t.Errorf("Serving /hop: Request.Host = %#v; want %#v", r.Host, virtualHost) + w.WriteHeader(404) + return + } + w.Header().Set("Location", tsURL+"/final") + w.WriteHeader(302) + case "/final": + if r.Host != tsHost { + t.Errorf("Serving /final: Request.Host = %#v; want %#v", r.Host, tsHost) + w.WriteHeader(404) + return + } + w.WriteHeader(200) + io.WriteString(w, wantBody) + default: + t.Errorf("Serving unexpected path %q", r.URL.Path) + w.WriteHeader(404) + } + })) + defer ts.Close() + tsURL = ts.URL + tsHost = strings.TrimPrefix(ts.URL, "http://") + t.Logf("Server host is %v", tsHost) + + c := ts.Client() + req, _ := NewRequest("GET", ts.URL, nil) + req.Host = virtualHost + resp, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + t.Fatal(resp.Status) + } + if got, err := ioutil.ReadAll(resp.Body); err != nil || string(got) != wantBody { + t.Errorf("body = %q; want %q", got, wantBody) + } +} + // Issue 17494: cookies should be altered when Client follows redirects. func TestClientAltersCookiesOnRedirect(t *testing.T) { cookieMap := func(cs []*Cookie) map[string][]string { @@ -1599,8 +1669,12 @@ func TestShouldCopyHeaderOnRedirect(t *testing.T) { {"www-authenticate", "http://foo.com/", "http://foo.com/", true}, {"www-authenticate", "http://foo.com/", "http://sub.foo.com/", true}, {"www-authenticate", "http://foo.com/", "http://notfoo.com/", false}, - // TODO(bradfitz): make this test work, once issue 16142 is fixed: - // {"www-authenticate", "http://foo.com:80/", "http://foo.com/", true}, + {"www-authenticate", "http://foo.com/", "https://foo.com/", false}, + {"www-authenticate", "http://foo.com:80/", "http://foo.com/", true}, + {"www-authenticate", "http://foo.com:80/", "http://sub.foo.com/", true}, + {"www-authenticate", "http://foo.com:443/", "https://foo.com/", true}, + {"www-authenticate", "http://foo.com:443/", "https://sub.foo.com/", true}, + {"www-authenticate", "http://foo.com:1234/", "http://foo.com/", false}, } for i, tt := range tests { u0, err := url.Parse(tt.initialURL) diff --git a/libgo/go/net/http/clientserver_test.go b/libgo/go/net/http/clientserver_test.go index 20feaa70ff6..c8d9fab8b7a 100644 --- a/libgo/go/net/http/clientserver_test.go +++ b/libgo/go/net/http/clientserver_test.go @@ -1145,27 +1145,6 @@ func testTransportRejectsInvalidHeaders(t *testing.T, h2 bool) { } } -// Tests that we support bogus under-100 HTTP statuses, because we historically -// have. This might change at some point, but not yet in Go 1.6. -func TestBogusStatusWorks_h1(t *testing.T) { testBogusStatusWorks(t, h1Mode) } -func TestBogusStatusWorks_h2(t *testing.T) { testBogusStatusWorks(t, h2Mode) } -func testBogusStatusWorks(t *testing.T, h2 bool) { - defer afterTest(t) - const code = 7 - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { - w.WriteHeader(code) - })) - defer cst.close() - - res, err := cst.c.Get(cst.ts.URL) - if err != nil { - t.Fatal(err) - } - if res.StatusCode != code { - t.Errorf("StatusCode = %d; want %d", res.StatusCode, code) - } -} - func TestInterruptWithPanic_h1(t *testing.T) { testInterruptWithPanic(t, h1Mode, "boom") } func TestInterruptWithPanic_h2(t *testing.T) { testInterruptWithPanic(t, h2Mode, "boom") } func TestInterruptWithPanic_nil_h1(t *testing.T) { testInterruptWithPanic(t, h1Mode, nil) } @@ -1412,3 +1391,40 @@ func TestBadResponseAfterReadingBody(t *testing.T) { t.Errorf("closes = %d; want 1", closes) } } + +func TestWriteHeader0_h1(t *testing.T) { testWriteHeader0(t, h1Mode) } +func TestWriteHeader0_h2(t *testing.T) { testWriteHeader0(t, h2Mode) } +func testWriteHeader0(t *testing.T, h2 bool) { + defer afterTest(t) + gotpanic := make(chan bool, 1) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + defer close(gotpanic) + defer func() { + if e := recover(); e != nil { + got := fmt.Sprintf("%T, %v", e, e) + want := "string, invalid WriteHeader code 0" + if got != want { + t.Errorf("unexpected panic value:\n got: %v\nwant: %v\n", got, want) + } + gotpanic <- true + + // Set an explicit 503. This also tests that the WriteHeader call panics + // before it recorded that an explicit value was set and that bogus + // value wasn't stuck. + w.WriteHeader(503) + } + }() + w.WriteHeader(0) + })) + defer cst.close() + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + t.Fatal(err) + } + if res.StatusCode != 503 { + t.Errorf("Response: %v %q; want 503", res.StatusCode, res.Status) + } + if !<-gotpanic { + t.Error("expected panic in handler") + } +} diff --git a/libgo/go/net/http/cookie.go b/libgo/go/net/http/cookie.go index cf522488c15..38b1b3630e2 100644 --- a/libgo/go/net/http/cookie.go +++ b/libgo/go/net/http/cookie.go @@ -208,7 +208,6 @@ func readCookies(h Header, filter string) []*Cookie { continue } // Per-line attributes - parsedPairs := 0 for i := 0; i < len(parts); i++ { parts[i] = strings.TrimSpace(parts[i]) if len(parts[i]) == 0 { @@ -229,7 +228,6 @@ func readCookies(h Header, filter string) []*Cookie { continue } cookies = append(cookies, &Cookie{Name: name, Value: val}) - parsedPairs++ } } return cookies diff --git a/libgo/go/net/http/example_test.go b/libgo/go/net/http/example_test.go index 1774795d379..9de0893e873 100644 --- a/libgo/go/net/http/example_test.go +++ b/libgo/go/net/http/example_test.go @@ -5,11 +5,14 @@ package http_test import ( + "context" "fmt" "io" "io/ioutil" "log" "net/http" + "os" + "os/signal" ) func ExampleHijacker() { @@ -109,3 +112,28 @@ func ExampleResponseWriter_trailers() { w.Header().Set("AtEnd3", "value 3") // These will appear as trailers. }) } + +func ExampleServer_Shutdown() { + var srv http.Server + + idleConnsClosed := make(chan struct{}) + go func() { + sigint := make(chan os.Signal, 1) + signal.Notify(sigint, os.Interrupt) + <-sigint + + // We received an interrupt signal, shut down. + if err := srv.Shutdown(context.Background()); err != nil { + // Error from closing listeners, or context timeout: + log.Printf("HTTP server Shutdown: %v", err) + } + close(idleConnsClosed) + }() + + if err := srv.ListenAndServe(); err != http.ErrServerClosed { + // Error starting or closing listener: + log.Printf("HTTP server ListenAndServe: %v", err) + } + + <-idleConnsClosed +} diff --git a/libgo/go/net/http/export_test.go b/libgo/go/net/http/export_test.go index 2ef145e5342..1825acd9be7 100644 --- a/libgo/go/net/http/export_test.go +++ b/libgo/go/net/http/export_test.go @@ -63,9 +63,14 @@ func SetPendingDialHooks(before, after func()) { func SetTestHookServerServe(fn func(*Server, net.Listener)) { testHookServerServe = fn } func NewTestTimeoutHandler(handler Handler, ch <-chan time.Time) Handler { + ctx, cancel := context.WithCancel(context.Background()) + go func() { + <-ch + cancel() + }() return &timeoutHandler{ handler: handler, - testTimeout: ch, + testContext: ctx, // (no body) } } @@ -206,3 +211,9 @@ func (s *Server) ExportAllConnsIdle() bool { func (r *Request) WithT(t *testing.T) *Request { return r.WithContext(context.WithValue(r.Context(), tLogKey{}, t.Logf)) } + +func ExportSetH2GoawayTimeout(d time.Duration) (restore func()) { + old := http2goAwayTimeout + http2goAwayTimeout = d + return func() { http2goAwayTimeout = old } +} diff --git a/libgo/go/net/http/fs.go b/libgo/go/net/http/fs.go index 5819334b5f4..ecad14ac1e4 100644 --- a/libgo/go/net/http/fs.go +++ b/libgo/go/net/http/fs.go @@ -98,12 +98,10 @@ type File interface { Stat() (os.FileInfo, error) } -func dirList(w ResponseWriter, f File) { +func dirList(w ResponseWriter, r *Request, f File) { dirs, err := f.Readdir(-1) if err != nil { - // TODO: log err.Error() to the Server.ErrorLog, once it's possible - // for a handler to get at its Server via the ResponseWriter. See - // Issue 12438. + logf(r, "http: error reading directory: %v", err) Error(w, "Error reading directory", StatusInternalServerError) return } @@ -319,7 +317,7 @@ func scanETag(s string) (etag string, remain string) { // Character values allowed in ETags. case c == 0x21 || c >= 0x23 && c <= 0x7E || c >= 0x80: case c == '"': - return string(s[:i+1]), s[i+1:] + return s[:i+1], s[i+1:] default: return "", "" } @@ -445,7 +443,7 @@ func checkIfModifiedSince(r *Request, modtime time.Time) condResult { } func checkIfRange(w ResponseWriter, r *Request, modtime time.Time) condResult { - if r.Method != "GET" { + if r.Method != "GET" && r.Method != "HEAD" { return condNone } ir := r.Header.get("If-Range") @@ -532,10 +530,8 @@ func checkPreconditions(w ResponseWriter, r *Request, modtime time.Time) (done b } rangeHeader = r.Header.get("Range") - if rangeHeader != "" { - if checkIfRange(w, r, modtime) == condFalse { - rangeHeader = "" - } + if rangeHeader != "" && checkIfRange(w, r, modtime) == condFalse { + rangeHeader = "" } return false, rangeHeader } @@ -615,7 +611,7 @@ func serveFile(w ResponseWriter, r *Request, fs FileSystem, name string, redirec return } w.Header().Set("Last-Modified", d.ModTime().UTC().Format(TimeFormat)) - dirList(w, f) + dirList(w, r, f) return } diff --git a/libgo/go/net/http/fs_test.go b/libgo/go/net/http/fs_test.go index f6eab0fcc31..fb8f9fe4c5c 100644 --- a/libgo/go/net/http/fs_test.go +++ b/libgo/go/net/http/fs_test.go @@ -895,6 +895,17 @@ func TestServeContent(t *testing.T) { wantContentRange: "bytes 0-4/8", wantLastMod: "Wed, 25 Jun 2014 17:12:18 GMT", }, + "range_with_modtime_mismatch": { + file: "testdata/style.css", + modtime: time.Date(2014, 6, 25, 17, 12, 18, 0 /* nanos */, time.UTC), + reqHeader: map[string]string{ + "Range": "bytes=0-4", + "If-Range": "Wed, 25 Jun 2014 17:12:19 GMT", + }, + wantStatus: StatusOK, + wantContentType: "text/css; charset=utf-8", + wantLastMod: "Wed, 25 Jun 2014 17:12:18 GMT", + }, "range_with_modtime_nanos": { file: "testdata/style.css", modtime: time.Date(2014, 6, 25, 17, 12, 18, 123 /* nanos */, time.UTC), @@ -937,8 +948,7 @@ func TestServeContent(t *testing.T) { reqHeader: map[string]string{ "If-Match": `"B"`, }, - wantStatus: 412, - wantContentType: "text/plain; charset=utf-8", + wantStatus: 412, }, "ifmatch_fails_on_weak_etag": { file: "testdata/style.css", @@ -946,8 +956,7 @@ func TestServeContent(t *testing.T) { reqHeader: map[string]string{ "If-Match": `W/"A"`, }, - wantStatus: 412, - wantContentType: "text/plain; charset=utf-8", + wantStatus: 412, }, "if_unmodified_since_true": { file: "testdata/style.css", @@ -965,9 +974,8 @@ func TestServeContent(t *testing.T) { reqHeader: map[string]string{ "If-Unmodified-Since": htmlModTime.Add(-2 * time.Second).UTC().Format(TimeFormat), }, - wantStatus: 412, - wantContentType: "text/plain; charset=utf-8", - wantLastMod: htmlModTime.UTC().Format(TimeFormat), + wantStatus: 412, + wantLastMod: htmlModTime.UTC().Format(TimeFormat), }, } for testName, tt := range tests { @@ -982,40 +990,46 @@ func TestServeContent(t *testing.T) { } else { content = tt.content } + for _, method := range []string{"GET", "HEAD"} { + //restore content in case it is consumed by previous method + if content, ok := content.(*strings.Reader); ok { + content.Seek(io.SeekStart, 0) + } - servec <- serveParam{ - name: filepath.Base(tt.file), - content: content, - modtime: tt.modtime, - etag: tt.serveETag, - contentType: tt.serveContentType, - } - req, err := NewRequest("GET", ts.URL, nil) - if err != nil { - t.Fatal(err) - } - for k, v := range tt.reqHeader { - req.Header.Set(k, v) - } + servec <- serveParam{ + name: filepath.Base(tt.file), + content: content, + modtime: tt.modtime, + etag: tt.serveETag, + contentType: tt.serveContentType, + } + req, err := NewRequest(method, ts.URL, nil) + if err != nil { + t.Fatal(err) + } + for k, v := range tt.reqHeader { + req.Header.Set(k, v) + } - c := ts.Client() - res, err := c.Do(req) - if err != nil { - t.Fatal(err) - } - io.Copy(ioutil.Discard, res.Body) - res.Body.Close() - if res.StatusCode != tt.wantStatus { - t.Errorf("test %q: status = %d; want %d", testName, res.StatusCode, tt.wantStatus) - } - if g, e := res.Header.Get("Content-Type"), tt.wantContentType; g != e { - t.Errorf("test %q: content-type = %q, want %q", testName, g, e) - } - if g, e := res.Header.Get("Content-Range"), tt.wantContentRange; g != e { - t.Errorf("test %q: content-range = %q, want %q", testName, g, e) - } - if g, e := res.Header.Get("Last-Modified"), tt.wantLastMod; g != e { - t.Errorf("test %q: last-modified = %q, want %q", testName, g, e) + c := ts.Client() + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + io.Copy(ioutil.Discard, res.Body) + res.Body.Close() + if res.StatusCode != tt.wantStatus { + t.Errorf("test %q using %q: got status = %d; want %d", testName, method, res.StatusCode, tt.wantStatus) + } + if g, e := res.Header.Get("Content-Type"), tt.wantContentType; g != e { + t.Errorf("test %q using %q: got content-type = %q, want %q", testName, method, g, e) + } + if g, e := res.Header.Get("Content-Range"), tt.wantContentRange; g != e { + t.Errorf("test %q using %q: got content-range = %q, want %q", testName, method, g, e) + } + if g, e := res.Header.Get("Last-Modified"), tt.wantLastMod; g != e { + t.Errorf("test %q using %q: got last-modified = %q, want %q", testName, method, g, e) + } } } } diff --git a/libgo/go/net/http/h2_bundle.go b/libgo/go/net/http/h2_bundle.go index 373f55098a3..f5a95084d24 100644 --- a/libgo/go/net/http/h2_bundle.go +++ b/libgo/go/net/http/h2_bundle.go @@ -30,6 +30,7 @@ import ( "io/ioutil" "log" "math" + mathrand "math/rand" "net" "net/http/httptrace" "net/textproto" @@ -3909,12 +3910,15 @@ func http2ConfigureServer(s *Server, conf *http2Server) error { } else if s.TLSConfig.CipherSuites != nil { // If they already provided a CipherSuite list, return // an error if it has a bad order or is missing - // ECDHE_RSA_WITH_AES_128_GCM_SHA256. - const requiredCipher = tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 + // ECDHE_RSA_WITH_AES_128_GCM_SHA256 or ECDHE_ECDSA_WITH_AES_128_GCM_SHA256. haveRequired := false sawBad := false for i, cs := range s.TLSConfig.CipherSuites { - if cs == requiredCipher { + switch cs { + case tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + // Alternative MTI cipher to not discourage ECDSA-only servers. + // See http://golang.org/cl/30721 for further information. + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: haveRequired = true } if http2isBadCipher(cs) { @@ -3924,7 +3928,7 @@ func http2ConfigureServer(s *Server, conf *http2Server) error { } } if !haveRequired { - return fmt.Errorf("http2: TLSConfig.CipherSuites is missing HTTP/2-required TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256") + return fmt.Errorf("http2: TLSConfig.CipherSuites is missing an HTTP/2-required AES_128_GCM_SHA256 cipher.") } } @@ -4341,7 +4345,7 @@ func (sc *http2serverConn) condlogf(err error, format string, args ...interface{ if err == nil { return } - if err == io.EOF || err == io.ErrUnexpectedEOF || http2isClosedConnError(err) { + if err == io.EOF || err == io.ErrUnexpectedEOF || http2isClosedConnError(err) || err == http2errPrefaceTimeout { // Boring, expected errors. sc.vlogf(format, args...) } else { @@ -4545,8 +4549,13 @@ func (sc *http2serverConn) serve() { } } - if sc.inGoAway && sc.curOpenStreams() == 0 && !sc.needToSendGoAway && !sc.writingFrame { - return + // Start the shutdown timer after sending a GOAWAY. When sending GOAWAY + // with no error code (graceful shutdown), don't start the timer until + // all open streams have been completed. + sentGoAway := sc.inGoAway && !sc.needToSendGoAway && !sc.writingFrame + gracefulShutdownComplete := sc.goAwayCode == http2ErrCodeNo && sc.curOpenStreams() == 0 + if sentGoAway && sc.shutdownTimer == nil && (sc.goAwayCode != http2ErrCodeNo || gracefulShutdownComplete) { + sc.shutDownIn(http2goAwayTimeout) } } } @@ -4583,8 +4592,11 @@ func (sc *http2serverConn) sendServeMsg(msg interface{}) { } } -// readPreface reads the ClientPreface greeting from the peer -// or returns an error on timeout or an invalid greeting. +var http2errPrefaceTimeout = errors.New("timeout waiting for client preface") + +// readPreface reads the ClientPreface greeting from the peer or +// returns errPrefaceTimeout on timeout, or an error if the greeting +// is invalid. func (sc *http2serverConn) readPreface() error { errc := make(chan error, 1) go func() { @@ -4602,7 +4614,7 @@ func (sc *http2serverConn) readPreface() error { defer timer.Stop() select { case <-timer.C: - return errors.New("timeout waiting for client preface") + return http2errPrefaceTimeout case err := <-errc: if err == nil { if http2VerboseLogs { @@ -4912,30 +4924,31 @@ func (sc *http2serverConn) startGracefulShutdown() { sc.shutdownOnce.Do(func() { sc.sendServeMsg(http2gracefulShutdownMsg) }) } +// After sending GOAWAY, the connection will close after goAwayTimeout. +// If we close the connection immediately after sending GOAWAY, there may +// be unsent data in our kernel receive buffer, which will cause the kernel +// to send a TCP RST on close() instead of a FIN. This RST will abort the +// connection immediately, whether or not the client had received the GOAWAY. +// +// Ideally we should delay for at least 1 RTT + epsilon so the client has +// a chance to read the GOAWAY and stop sending messages. Measuring RTT +// is hard, so we approximate with 1 second. See golang.org/issue/18701. +// +// This is a var so it can be shorter in tests, where all requests uses the +// loopback interface making the expected RTT very small. +// +// TODO: configurable? +var http2goAwayTimeout = 1 * time.Second + func (sc *http2serverConn) startGracefulShutdownInternal() { - sc.goAwayIn(http2ErrCodeNo, 0) + sc.goAway(http2ErrCodeNo) } func (sc *http2serverConn) goAway(code http2ErrCode) { sc.serveG.check() - var forceCloseIn time.Duration - if code != http2ErrCodeNo { - forceCloseIn = 250 * time.Millisecond - } else { - // TODO: configurable - forceCloseIn = 1 * time.Second - } - sc.goAwayIn(code, forceCloseIn) -} - -func (sc *http2serverConn) goAwayIn(code http2ErrCode, forceCloseIn time.Duration) { - sc.serveG.check() if sc.inGoAway { return } - if forceCloseIn != 0 { - sc.shutDownIn(forceCloseIn) - } sc.inGoAway = true sc.needToSendGoAway = true sc.goAwayCode = code @@ -6004,7 +6017,7 @@ func (rws *http2responseWriterState) writeChunk(p []byte) (n int, err error) { clen = strconv.Itoa(len(p)) } _, hasContentType := rws.snapHeader["Content-Type"] - if !hasContentType && http2bodyAllowedForStatus(rws.status) { + if !hasContentType && http2bodyAllowedForStatus(rws.status) && len(p) > 0 { ctype = DetectContentType(p) } var date string @@ -6172,7 +6185,26 @@ func (w *http2responseWriter) Header() Header { return rws.handlerHeader } +// checkWriteHeaderCode is a copy of net/http's checkWriteHeaderCode. +func http2checkWriteHeaderCode(code int) { + // Issue 22880: require valid WriteHeader status codes. + // For now we only enforce that it's three digits. + // In the future we might block things over 599 (600 and above aren't defined + // at http://httpwg.org/specs/rfc7231.html#status.codes) + // and we might block under 200 (once we have more mature 1xx support). + // But for now any three digits. + // + // We used to send "HTTP/1.1 000 0" on the wire in responses but there's + // no equivalent bogus thing we can realistically send in HTTP/2, + // so we'll consistently panic instead and help people find their bugs + // early. (We can't return an error from WriteHeader even if we wanted to.) + if code < 100 || code > 999 { + panic(fmt.Sprintf("invalid WriteHeader code %v", code)) + } +} + func (w *http2responseWriter) WriteHeader(code int) { + http2checkWriteHeaderCode(code) rws := w.rws if rws == nil { panic("WriteHeader called after Handler finished") @@ -6605,7 +6637,7 @@ type http2Transport struct { // MaxHeaderListSize is the http2 SETTINGS_MAX_HEADER_LIST_SIZE to // send in the initial settings frame. It is how many bytes - // of response headers are allow. Unlike the http2 spec, zero here + // of response headers are allowed. Unlike the http2 spec, zero here // means to use a default limit (currently 10MB). If you actually // want to advertise an ulimited value to the peer, Transport // interprets the highest possible value here (0xffffffff or 1<<32-1) @@ -6683,15 +6715,17 @@ type http2ClientConn struct { goAwayDebug string // goAway frame's debug data, retained as a string streams map[uint32]*http2clientStream // client-initiated nextStreamID uint32 + pendingRequests int // requests blocked and waiting to be sent because len(streams) == maxConcurrentStreams pings map[[8]byte]chan struct{} // in flight ping data to notification channel bw *bufio.Writer br *bufio.Reader fr *http2Framer lastActive time.Time // Settings from peer: (also guarded by mu) - maxFrameSize uint32 - maxConcurrentStreams uint32 - initialWindowSize uint32 + maxFrameSize uint32 + maxConcurrentStreams uint32 + peerMaxHeaderListSize uint64 + initialWindowSize uint32 hbuf bytes.Buffer // HPACK encoder writes into this henc *hpack.Encoder @@ -6735,35 +6769,45 @@ type http2clientStream struct { resTrailer *Header // client's Response.Trailer } -// awaitRequestCancel runs in its own goroutine and waits for the user -// to cancel a RoundTrip request, its context to expire, or for the -// request to be done (any way it might be removed from the cc.streams -// map: peer reset, successful completion, TCP connection breakage, -// etc) -func (cs *http2clientStream) awaitRequestCancel(req *Request) { +// awaitRequestCancel waits for the user to cancel a request or for the done +// channel to be signaled. A non-nil error is returned only if the request was +// canceled. +func http2awaitRequestCancel(req *Request, done <-chan struct{}) error { ctx := http2reqContext(req) if req.Cancel == nil && ctx.Done() == nil { - return + return nil } select { case <-req.Cancel: - cs.cancelStream() - cs.bufPipe.CloseWithError(http2errRequestCanceled) + return http2errRequestCanceled case <-ctx.Done(): + return ctx.Err() + case <-done: + return nil + } +} + +// awaitRequestCancel waits for the user to cancel a request, its context to +// expire, or for the request to be done (any way it might be removed from the +// cc.streams map: peer reset, successful completion, TCP connection breakage, +// etc). If the request is canceled, then cs will be canceled and closed. +func (cs *http2clientStream) awaitRequestCancel(req *Request) { + if err := http2awaitRequestCancel(req, cs.done); err != nil { cs.cancelStream() - cs.bufPipe.CloseWithError(ctx.Err()) - case <-cs.done: + cs.bufPipe.CloseWithError(err) } } func (cs *http2clientStream) cancelStream() { - cs.cc.mu.Lock() + cc := cs.cc + cc.mu.Lock() didReset := cs.didReset cs.didReset = true - cs.cc.mu.Unlock() + cc.mu.Unlock() if !didReset { - cs.cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) + cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) + cc.forgetStreamID(cs.ID) } } @@ -6780,6 +6824,13 @@ func (cs *http2clientStream) checkResetOrDone() error { } } +func (cs *http2clientStream) getStartedWrite() bool { + cc := cs.cc + cc.mu.Lock() + defer cc.mu.Unlock() + return cs.startedWrite +} + func (cs *http2clientStream) abortRequestBodyWrite(err error) { if err == nil { panic("nil error") @@ -6848,17 +6899,28 @@ func (t *http2Transport) RoundTripOpt(req *Request, opt http2RoundTripOpt) (*Res } addr := http2authorityAddr(req.URL.Scheme, req.URL.Host) - for { + for retry := 0; ; retry++ { cc, err := t.connPool().GetClientConn(req, addr) if err != nil { t.vlogf("http2: Transport failed to get client conn for %s: %v", addr, err) return nil, err } http2traceGotConn(req, cc) - res, err := cc.RoundTrip(req) - if err != nil { - if req, err = http2shouldRetryRequest(req, err); err == nil { - continue + res, gotErrAfterReqBodyWrite, err := cc.roundTrip(req) + if err != nil && retry <= 6 { + if req, err = http2shouldRetryRequest(req, err, gotErrAfterReqBodyWrite); err == nil { + // After the first retry, do exponential backoff with 10% jitter. + if retry == 0 { + continue + } + backoff := float64(uint(1) << (uint(retry) - 1)) + backoff += backoff * (0.1 * mathrand.Float64()) + select { + case <-time.After(time.Second * time.Duration(backoff)): + continue + case <-http2reqContext(req).Done(): + return nil, http2reqContext(req).Err() + } } } if err != nil { @@ -6879,43 +6941,50 @@ func (t *http2Transport) CloseIdleConnections() { } var ( - http2errClientConnClosed = errors.New("http2: client conn is closed") - http2errClientConnUnusable = errors.New("http2: client conn not usable") - - http2errClientConnGotGoAway = errors.New("http2: Transport received Server's graceful shutdown GOAWAY") - http2errClientConnGotGoAwayAfterSomeReqBody = errors.New("http2: Transport received Server's graceful shutdown GOAWAY; some request body already written") + http2errClientConnClosed = errors.New("http2: client conn is closed") + http2errClientConnUnusable = errors.New("http2: client conn not usable") + http2errClientConnGotGoAway = errors.New("http2: Transport received Server's graceful shutdown GOAWAY") ) // shouldRetryRequest is called by RoundTrip when a request fails to get // response headers. It is always called with a non-nil error. // It returns either a request to retry (either the same request, or a // modified clone), or an error if the request can't be replayed. -func http2shouldRetryRequest(req *Request, err error) (*Request, error) { - switch err { - default: +func http2shouldRetryRequest(req *Request, err error, afterBodyWrite bool) (*Request, error) { + if !http2canRetryError(err) { return nil, err - case http2errClientConnUnusable, http2errClientConnGotGoAway: + } + if !afterBodyWrite { return req, nil - case http2errClientConnGotGoAwayAfterSomeReqBody: - // If the Body is nil (or http.NoBody), it's safe to reuse - // this request and its Body. - if req.Body == nil || http2reqBodyIsNoBody(req.Body) { - return req, nil - } - // Otherwise we depend on the Request having its GetBody - // func defined. - getBody := http2reqGetBody(req) // Go 1.8: getBody = req.GetBody - if getBody == nil { - return nil, errors.New("http2: Transport: peer server initiated graceful shutdown after some of Request.Body was written; define Request.GetBody to avoid this error") - } - body, err := getBody() - if err != nil { - return nil, err - } - newReq := *req - newReq.Body = body - return &newReq, nil } + // If the Body is nil (or http.NoBody), it's safe to reuse + // this request and its Body. + if req.Body == nil || http2reqBodyIsNoBody(req.Body) { + return req, nil + } + // Otherwise we depend on the Request having its GetBody + // func defined. + getBody := http2reqGetBody(req) // Go 1.8: getBody = req.GetBody + if getBody == nil { + return nil, fmt.Errorf("http2: Transport: cannot retry err [%v] after Request.Body was written; define Request.GetBody to avoid this error", err) + } + body, err := getBody() + if err != nil { + return nil, err + } + newReq := *req + newReq.Body = body + return &newReq, nil +} + +func http2canRetryError(err error) bool { + if err == http2errClientConnUnusable || err == http2errClientConnGotGoAway { + return true + } + if se, ok := err.(http2StreamError); ok { + return se.Code == http2ErrCodeRefusedStream + } + return false } func (t *http2Transport) dialClientConn(addr string, singleUse bool) (*http2ClientConn, error) { @@ -6993,17 +7062,18 @@ func (t *http2Transport) NewClientConn(c net.Conn) (*http2ClientConn, error) { func (t *http2Transport) newClientConn(c net.Conn, singleUse bool) (*http2ClientConn, error) { cc := &http2ClientConn{ - t: t, - tconn: c, - readerDone: make(chan struct{}), - nextStreamID: 1, - maxFrameSize: 16 << 10, // spec default - initialWindowSize: 65535, // spec default - maxConcurrentStreams: 1000, // "infinite", per spec. 1000 seems good enough. - streams: make(map[uint32]*http2clientStream), - singleUse: singleUse, - wantSettingsAck: true, - pings: make(map[[8]byte]chan struct{}), + t: t, + tconn: c, + readerDone: make(chan struct{}), + nextStreamID: 1, + maxFrameSize: 16 << 10, // spec default + initialWindowSize: 65535, // spec default + maxConcurrentStreams: 1000, // "infinite", per spec. 1000 seems good enough. + peerMaxHeaderListSize: 0xffffffffffffffff, // "infinite", per spec. Use 2^64-1 instead. + streams: make(map[uint32]*http2clientStream), + singleUse: singleUse, + wantSettingsAck: true, + pings: make(map[[8]byte]chan struct{}), } if d := t.idleConnTimeout(); d != 0 { cc.idleTimeout = d @@ -7079,6 +7149,8 @@ func (cc *http2ClientConn) setGoAway(f *http2GoAwayFrame) { } } +// CanTakeNewRequest reports whether the connection can take a new request, +// meaning it has not been closed or received or sent a GOAWAY. func (cc *http2ClientConn) CanTakeNewRequest() bool { cc.mu.Lock() defer cc.mu.Unlock() @@ -7090,8 +7162,7 @@ func (cc *http2ClientConn) canTakeNewRequestLocked() bool { return false } return cc.goAway == nil && !cc.closed && - int64(len(cc.streams)+1) < int64(cc.maxConcurrentStreams) && - cc.nextStreamID < math.MaxInt32 + int64(cc.nextStreamID)+int64(cc.pendingRequests) < math.MaxInt32 } // onIdleTimeout is called from a time.AfterFunc goroutine. It will @@ -7223,8 +7294,13 @@ func http2actualContentLength(req *Request) int64 { } func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { + resp, _, err := cc.roundTrip(req) + return resp, err +} + +func (cc *http2ClientConn) roundTrip(req *Request) (res *Response, gotErrAfterReqBodyWrite bool, err error) { if err := http2checkConnHeaders(req); err != nil { - return nil, err + return nil, false, err } if cc.idleTimer != nil { cc.idleTimer.Stop() @@ -7232,15 +7308,14 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { trailers, err := http2commaSeparatedTrailers(req) if err != nil { - return nil, err + return nil, false, err } hasTrailers := trailers != "" cc.mu.Lock() - cc.lastActive = time.Now() - if cc.closed || !cc.canTakeNewRequestLocked() { + if err := cc.awaitOpenSlotForRequest(req); err != nil { cc.mu.Unlock() - return nil, http2errClientConnUnusable + return nil, false, err } body := req.Body @@ -7274,7 +7349,7 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { hdrs, err := cc.encodeHeaders(req, requestedGzip, trailers, contentLen) if err != nil { cc.mu.Unlock() - return nil, err + return nil, false, err } cs := cc.newStream() @@ -7286,7 +7361,7 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { cc.wmu.Lock() endStream := !hasBody && !hasTrailers - werr := cc.writeHeaders(cs.ID, endStream, hdrs) + werr := cc.writeHeaders(cs.ID, endStream, int(cc.maxFrameSize), hdrs) cc.wmu.Unlock() http2traceWroteHeaders(cs.trace) cc.mu.Unlock() @@ -7300,7 +7375,7 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { // Don't bother sending a RST_STREAM (our write already failed; // no need to keep writing) http2traceWroteRequest(cs.trace, werr) - return nil, werr + return nil, false, werr } var respHeaderTimer <-chan time.Time @@ -7319,7 +7394,7 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { bodyWritten := false ctx := http2reqContext(req) - handleReadLoopResponse := func(re http2resAndError) (*Response, error) { + handleReadLoopResponse := func(re http2resAndError) (*Response, bool, error) { res := re.res if re.err != nil || res.StatusCode > 299 { // On error or status code 3xx, 4xx, 5xx, etc abort any @@ -7335,19 +7410,12 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { cs.abortRequestBodyWrite(http2errStopReqBodyWrite) } if re.err != nil { - if re.err == http2errClientConnGotGoAway { - cc.mu.Lock() - if cs.startedWrite { - re.err = http2errClientConnGotGoAwayAfterSomeReqBody - } - cc.mu.Unlock() - } cc.forgetStreamID(cs.ID) - return nil, re.err + return nil, cs.getStartedWrite(), re.err } res.Request = req res.TLS = cc.tlsState - return res, nil + return res, false, nil } for { @@ -7355,42 +7423,42 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { case re := <-readLoopResCh: return handleReadLoopResponse(re) case <-respHeaderTimer: - cc.forgetStreamID(cs.ID) if !hasBody || bodyWritten { cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) } else { bodyWriter.cancel() cs.abortRequestBodyWrite(http2errStopReqBodyWriteAndCancel) } - return nil, http2errTimeout + cc.forgetStreamID(cs.ID) + return nil, cs.getStartedWrite(), http2errTimeout case <-ctx.Done(): select { case re := <-readLoopResCh: return handleReadLoopResponse(re) default: } - cc.forgetStreamID(cs.ID) if !hasBody || bodyWritten { cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) } else { bodyWriter.cancel() cs.abortRequestBodyWrite(http2errStopReqBodyWriteAndCancel) } - return nil, ctx.Err() + cc.forgetStreamID(cs.ID) + return nil, cs.getStartedWrite(), ctx.Err() case <-req.Cancel: select { case re := <-readLoopResCh: return handleReadLoopResponse(re) default: } - cc.forgetStreamID(cs.ID) if !hasBody || bodyWritten { cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) } else { bodyWriter.cancel() cs.abortRequestBodyWrite(http2errStopReqBodyWriteAndCancel) } - return nil, http2errRequestCanceled + cc.forgetStreamID(cs.ID) + return nil, cs.getStartedWrite(), http2errRequestCanceled case <-cs.peerReset: select { case re := <-readLoopResCh: @@ -7400,7 +7468,7 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { // processResetStream already removed the // stream from the streams map; no need for // forgetStreamID. - return nil, cs.resetErr + return nil, cs.getStartedWrite(), cs.resetErr case err := <-bodyWriter.resc: // Prefer the read loop's response, if available. Issue 16102. select { @@ -7409,7 +7477,7 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { default: } if err != nil { - return nil, err + return nil, cs.getStartedWrite(), err } bodyWritten = true if d := cc.responseHeaderTimeout(); d != 0 { @@ -7421,14 +7489,52 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { } } +// awaitOpenSlotForRequest waits until len(streams) < maxConcurrentStreams. +// Must hold cc.mu. +func (cc *http2ClientConn) awaitOpenSlotForRequest(req *Request) error { + var waitingForConn chan struct{} + var waitingForConnErr error // guarded by cc.mu + for { + cc.lastActive = time.Now() + if cc.closed || !cc.canTakeNewRequestLocked() { + return http2errClientConnUnusable + } + if int64(len(cc.streams))+1 <= int64(cc.maxConcurrentStreams) { + if waitingForConn != nil { + close(waitingForConn) + } + return nil + } + // Unfortunately, we cannot wait on a condition variable and channel at + // the same time, so instead, we spin up a goroutine to check if the + // request is canceled while we wait for a slot to open in the connection. + if waitingForConn == nil { + waitingForConn = make(chan struct{}) + go func() { + if err := http2awaitRequestCancel(req, waitingForConn); err != nil { + cc.mu.Lock() + waitingForConnErr = err + cc.cond.Broadcast() + cc.mu.Unlock() + } + }() + } + cc.pendingRequests++ + cc.cond.Wait() + cc.pendingRequests-- + if waitingForConnErr != nil { + return waitingForConnErr + } + } +} + // requires cc.wmu be held -func (cc *http2ClientConn) writeHeaders(streamID uint32, endStream bool, hdrs []byte) error { +func (cc *http2ClientConn) writeHeaders(streamID uint32, endStream bool, maxFrameSize int, hdrs []byte) error { first := true // first frame written (HEADERS is first, then CONTINUATION) - frameSize := int(cc.maxFrameSize) for len(hdrs) > 0 && cc.werr == nil { chunk := hdrs - if len(chunk) > frameSize { - chunk = chunk[:frameSize] + if len(chunk) > maxFrameSize { + chunk = chunk[:maxFrameSize] } hdrs = hdrs[len(chunk):] endHeaders := len(hdrs) == 0 @@ -7536,17 +7642,26 @@ func (cs *http2clientStream) writeRequestBody(body io.Reader, bodyCloser io.Clos var trls []byte if hasTrailers { cc.mu.Lock() - defer cc.mu.Unlock() - trls = cc.encodeTrailers(req) + trls, err = cc.encodeTrailers(req) + cc.mu.Unlock() + if err != nil { + cc.writeStreamReset(cs.ID, http2ErrCodeInternal, err) + cc.forgetStreamID(cs.ID) + return err + } } + cc.mu.Lock() + maxFrameSize := int(cc.maxFrameSize) + cc.mu.Unlock() + cc.wmu.Lock() defer cc.wmu.Unlock() // Two ways to send END_STREAM: either with trailers, or // with an empty DATA frame. if len(trls) > 0 { - err = cc.writeHeaders(cs.ID, true, trls) + err = cc.writeHeaders(cs.ID, true, maxFrameSize, trls) } else { err = cc.fr.WriteData(cs.ID, true, nil) } @@ -7640,62 +7755,86 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail } } - // 8.1.2.3 Request Pseudo-Header Fields - // The :path pseudo-header field includes the path and query parts of the - // target URI (the path-absolute production and optionally a '?' character - // followed by the query production (see Sections 3.3 and 3.4 of - // [RFC3986]). - cc.writeHeader(":authority", host) - cc.writeHeader(":method", req.Method) - if req.Method != "CONNECT" { - cc.writeHeader(":path", path) - cc.writeHeader(":scheme", req.URL.Scheme) - } - if trailers != "" { - cc.writeHeader("trailer", trailers) - } - - var didUA bool - for k, vv := range req.Header { - lowKey := strings.ToLower(k) - switch lowKey { - case "host", "content-length": - // Host is :authority, already sent. - // Content-Length is automatic, set below. - continue - case "connection", "proxy-connection", "transfer-encoding", "upgrade", "keep-alive": - // Per 8.1.2.2 Connection-Specific Header - // Fields, don't send connection-specific - // fields. We have already checked if any - // are error-worthy so just ignore the rest. - continue - case "user-agent": - // Match Go's http1 behavior: at most one - // User-Agent. If set to nil or empty string, - // then omit it. Otherwise if not mentioned, - // include the default (below). - didUA = true - if len(vv) < 1 { + enumerateHeaders := func(f func(name, value string)) { + // 8.1.2.3 Request Pseudo-Header Fields + // The :path pseudo-header field includes the path and query parts of the + // target URI (the path-absolute production and optionally a '?' character + // followed by the query production (see Sections 3.3 and 3.4 of + // [RFC3986]). + f(":authority", host) + f(":method", req.Method) + if req.Method != "CONNECT" { + f(":path", path) + f(":scheme", req.URL.Scheme) + } + if trailers != "" { + f("trailer", trailers) + } + + var didUA bool + for k, vv := range req.Header { + if strings.EqualFold(k, "host") || strings.EqualFold(k, "content-length") { + // Host is :authority, already sent. + // Content-Length is automatic, set below. continue - } - vv = vv[:1] - if vv[0] == "" { + } else if strings.EqualFold(k, "connection") || strings.EqualFold(k, "proxy-connection") || + strings.EqualFold(k, "transfer-encoding") || strings.EqualFold(k, "upgrade") || + strings.EqualFold(k, "keep-alive") { + // Per 8.1.2.2 Connection-Specific Header + // Fields, don't send connection-specific + // fields. We have already checked if any + // are error-worthy so just ignore the rest. continue + } else if strings.EqualFold(k, "user-agent") { + // Match Go's http1 behavior: at most one + // User-Agent. If set to nil or empty string, + // then omit it. Otherwise if not mentioned, + // include the default (below). + didUA = true + if len(vv) < 1 { + continue + } + vv = vv[:1] + if vv[0] == "" { + continue + } + + } + + for _, v := range vv { + f(k, v) } } - for _, v := range vv { - cc.writeHeader(lowKey, v) + if http2shouldSendReqContentLength(req.Method, contentLength) { + f("content-length", strconv.FormatInt(contentLength, 10)) + } + if addGzipHeader { + f("accept-encoding", "gzip") + } + if !didUA { + f("user-agent", http2defaultUserAgent) } } - if http2shouldSendReqContentLength(req.Method, contentLength) { - cc.writeHeader("content-length", strconv.FormatInt(contentLength, 10)) - } - if addGzipHeader { - cc.writeHeader("accept-encoding", "gzip") - } - if !didUA { - cc.writeHeader("user-agent", http2defaultUserAgent) + + // Do a first pass over the headers counting bytes to ensure + // we don't exceed cc.peerMaxHeaderListSize. This is done as a + // separate pass before encoding the headers to prevent + // modifying the hpack state. + hlSize := uint64(0) + enumerateHeaders(func(name, value string) { + hf := hpack.HeaderField{Name: name, Value: value} + hlSize += uint64(hf.Size()) + }) + + if hlSize > cc.peerMaxHeaderListSize { + return nil, http2errRequestHeaderListSize } + + // Header list size is ok. Write the headers. + enumerateHeaders(func(name, value string) { + cc.writeHeader(strings.ToLower(name), value) + }) + return cc.hbuf.Bytes(), nil } @@ -7722,17 +7861,29 @@ func http2shouldSendReqContentLength(method string, contentLength int64) bool { } // requires cc.mu be held. -func (cc *http2ClientConn) encodeTrailers(req *Request) []byte { +func (cc *http2ClientConn) encodeTrailers(req *Request) ([]byte, error) { cc.hbuf.Reset() + + hlSize := uint64(0) + for k, vv := range req.Trailer { + for _, v := range vv { + hf := hpack.HeaderField{Name: k, Value: v} + hlSize += uint64(hf.Size()) + } + } + if hlSize > cc.peerMaxHeaderListSize { + return nil, http2errRequestHeaderListSize + } + for k, vv := range req.Trailer { - // Transfer-Encoding, etc.. have already been filter at the + // Transfer-Encoding, etc.. have already been filtered at the // start of RoundTrip lowKey := strings.ToLower(k) for _, v := range vv { cc.writeHeader(lowKey, v) } } - return cc.hbuf.Bytes() + return cc.hbuf.Bytes(), nil } func (cc *http2ClientConn) writeHeader(name, value string) { @@ -7780,7 +7931,9 @@ func (cc *http2ClientConn) streamByID(id uint32, andRemove bool) *http2clientStr cc.idleTimer.Reset(cc.idleTimeout) } close(cs.done) - cc.cond.Broadcast() // wake up checkResetOrDone via clientStream.awaitFlowControl + // Wake up checkResetOrDone via clientStream.awaitFlowControl and + // wake up RoundTrip if there is a pending request. + cc.cond.Broadcast() } return cs } @@ -7788,17 +7941,12 @@ func (cc *http2ClientConn) streamByID(id uint32, andRemove bool) *http2clientStr // clientConnReadLoop is the state owned by the clientConn's frame-reading readLoop. type http2clientConnReadLoop struct { cc *http2ClientConn - activeRes map[uint32]*http2clientStream // keyed by streamID closeWhenIdle bool } // readLoop runs in its own goroutine and reads and dispatches frames. func (cc *http2ClientConn) readLoop() { - rl := &http2clientConnReadLoop{ - cc: cc, - activeRes: make(map[uint32]*http2clientStream), - } - + rl := &http2clientConnReadLoop{cc: cc} defer rl.cleanup() cc.readerErr = rl.run() if ce, ok := cc.readerErr.(http2ConnectionError); ok { @@ -7853,10 +8001,8 @@ func (rl *http2clientConnReadLoop) cleanup() { } else if err == io.EOF { err = io.ErrUnexpectedEOF } - for _, cs := range rl.activeRes { - cs.bufPipe.CloseWithError(err) - } for _, cs := range cc.streams { + cs.bufPipe.CloseWithError(err) // no-op if already closed select { case cs.resc <- http2resAndError{err: err}: default: @@ -7879,8 +8025,9 @@ func (rl *http2clientConnReadLoop) run() error { cc.vlogf("http2: Transport readFrame error on conn %p: (%T) %v", cc, err, err) } if se, ok := err.(http2StreamError); ok { - if cs := cc.streamByID(se.StreamID, true /*ended; remove it*/); cs != nil { + if cs := cc.streamByID(se.StreamID, false); cs != nil { cs.cc.writeStreamReset(cs.ID, se.Code, err) + cs.cc.forgetStreamID(cs.ID) if se.Cause == nil { se.Cause = cc.fr.errDetail } @@ -7933,7 +8080,7 @@ func (rl *http2clientConnReadLoop) run() error { } return err } - if rl.closeWhenIdle && gotReply && maybeIdle && len(rl.activeRes) == 0 { + if rl.closeWhenIdle && gotReply && maybeIdle { cc.closeIfIdle() } } @@ -7941,13 +8088,31 @@ func (rl *http2clientConnReadLoop) run() error { func (rl *http2clientConnReadLoop) processHeaders(f *http2MetaHeadersFrame) error { cc := rl.cc - cs := cc.streamByID(f.StreamID, f.StreamEnded()) + cs := cc.streamByID(f.StreamID, false) if cs == nil { // We'd get here if we canceled a request while the // server had its response still in flight. So if this // was just something we canceled, ignore it. return nil } + if f.StreamEnded() { + // Issue 20521: If the stream has ended, streamByID() causes + // clientStream.done to be closed, which causes the request's bodyWriter + // to be closed with an errStreamClosed, which may be received by + // clientConn.RoundTrip before the result of processing these headers. + // Deferring stream closure allows the header processing to occur first. + // clientConn.RoundTrip may still receive the bodyWriter error first, but + // the fix for issue 16102 prioritises any response. + // + // Issue 22413: If there is no request body, we should close the + // stream before writing to cs.resc so that the stream is closed + // immediately once RoundTrip returns. + if cs.req.Body != nil { + defer cc.forgetStreamID(f.StreamID) + } else { + cc.forgetStreamID(f.StreamID) + } + } if !cs.firstByte { if cs.trace != nil { // TODO(bradfitz): move first response byte earlier, @@ -7971,6 +8136,7 @@ func (rl *http2clientConnReadLoop) processHeaders(f *http2MetaHeadersFrame) erro } // Any other error type is a stream error. cs.cc.writeStreamReset(f.StreamID, http2ErrCodeProtocol, err) + cc.forgetStreamID(cs.ID) cs.resc <- http2resAndError{err: err} return nil // return nil from process* funcs to keep conn alive } @@ -7978,9 +8144,6 @@ func (rl *http2clientConnReadLoop) processHeaders(f *http2MetaHeadersFrame) erro // (nil, nil) special case. See handleResponse docs. return nil } - if res.Body != http2noBody { - rl.activeRes[cs.ID] = cs - } cs.resTrailer = &res.Trailer cs.resc <- http2resAndError{res: res} return nil @@ -8000,11 +8163,11 @@ func (rl *http2clientConnReadLoop) handleResponse(cs *http2clientStream, f *http status := f.PseudoValue("status") if status == "" { - return nil, errors.New("missing status pseudo header") + return nil, errors.New("malformed response from server: missing status pseudo header") } statusCode, err := strconv.Atoi(status) if err != nil { - return nil, errors.New("malformed non-numeric status pseudo header") + return nil, errors.New("malformed response from server: malformed non-numeric status pseudo header") } if statusCode == 100 { @@ -8202,6 +8365,7 @@ func (b http2transportResponseBody) Close() error { } cs.bufPipe.BreakWithError(http2errClosedResponseBody) + cc.forgetStreamID(cs.ID) return nil } @@ -8236,7 +8400,23 @@ func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error { } return nil } + if !cs.firstByte { + cc.logf("protocol error: received DATA before a HEADERS frame") + rl.endStreamError(cs, http2StreamError{ + StreamID: f.StreamID, + Code: http2ErrCodeProtocol, + }) + return nil + } if f.Length > 0 { + if cs.req.Method == "HEAD" && len(data) > 0 { + cc.logf("protocol error: received DATA on a HEAD request") + rl.endStreamError(cs, http2StreamError{ + StreamID: f.StreamID, + Code: http2ErrCodeProtocol, + }) + return nil + } // Check connection-level flow control. cc.mu.Lock() if cs.inflow.available() >= int32(f.Length) { @@ -8298,11 +8478,10 @@ func (rl *http2clientConnReadLoop) endStreamError(cs *http2clientStream, err err err = io.EOF code = cs.copyTrailers } - cs.bufPipe.closeWithErrorAndCode(err, code) - delete(rl.activeRes, cs.ID) if http2isConnectionCloseRequest(cs.req) { rl.closeWhenIdle = true } + cs.bufPipe.closeWithErrorAndCode(err, code) select { case cs.resc <- http2resAndError{err: err}: @@ -8350,6 +8529,8 @@ func (rl *http2clientConnReadLoop) processSettings(f *http2SettingsFrame) error cc.maxFrameSize = s.Val case http2SettingMaxConcurrentStreams: cc.maxConcurrentStreams = s.Val + case http2SettingMaxHeaderListSize: + cc.peerMaxHeaderListSize = uint64(s.Val) case http2SettingInitialWindowSize: // Values above the maximum flow-control // window size of 2^31-1 MUST be treated as a @@ -8427,7 +8608,6 @@ func (rl *http2clientConnReadLoop) processResetStream(f *http2RSTStreamFrame) er cs.bufPipe.CloseWithError(err) cs.cc.cond.Broadcast() // wake up checkResetOrDone via clientStream.awaitFlowControl } - delete(rl.activeRes, cs.ID) return nil } @@ -8516,6 +8696,7 @@ func (cc *http2ClientConn) writeStreamReset(streamID uint32, code http2ErrCode, var ( http2errResponseHeaderListSize = errors.New("http2: response header list larger than advertised limit") + http2errRequestHeaderListSize = errors.New("http2: request header list larger than peer's advertised limit") http2errPseudoTrailers = errors.New("http2: invalid pseudo header in trailers") ) @@ -8741,11 +8922,7 @@ type http2writeGoAway struct { func (p *http2writeGoAway) writeFrame(ctx http2writeContext) error { err := ctx.Framer().WriteGoAway(p.maxStreamID, p.code, nil) - if p.code != 0 { - ctx.Flush() // ignore error: we're hanging up on them anyway - time.Sleep(50 * time.Millisecond) - ctx.CloseConn() - } + ctx.Flush() // ignore error: we're hanging up on them anyway return err } diff --git a/libgo/go/net/http/header.go b/libgo/go/net/http/header.go index 832169247fe..622ad289636 100644 --- a/libgo/go/net/http/header.go +++ b/libgo/go/net/http/header.go @@ -156,6 +156,7 @@ func (h Header) WriteSubset(w io.Writer, exclude map[string]bool) error { v = textproto.TrimString(v) for _, s := range []string{kv.key, ": ", v, "\r\n"} { if _, err := ws.WriteString(s); err != nil { + headerSorterPool.Put(sorter) return err } } diff --git a/libgo/go/net/http/httputil/dump_test.go b/libgo/go/net/http/httputil/dump_test.go index f881020fef7..5703a7fb866 100644 --- a/libgo/go/net/http/httputil/dump_test.go +++ b/libgo/go/net/http/httputil/dump_test.go @@ -27,7 +27,6 @@ type dumpTest struct { } var dumpTests = []dumpTest{ - // HTTP/1.1 => chunked coding; body; empty trailer { Req: http.Request{ @@ -214,7 +213,6 @@ func TestDumpRequest(t *testing.T) { t.Fatalf("Test %d: unsupported Body of %T", i, tt.Body) } } - setBody() if tt.Req.Header == nil { tt.Req.Header = make(http.Header) } diff --git a/libgo/go/net/http/httputil/reverseproxy.go b/libgo/go/net/http/httputil/reverseproxy.go index 0d514f529ba..aa22d5a2fdd 100644 --- a/libgo/go/net/http/httputil/reverseproxy.go +++ b/libgo/go/net/http/httputil/reverseproxy.go @@ -169,15 +169,7 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { p.Director(outreq) outreq.Close = false - // Remove hop-by-hop headers listed in the "Connection" header. - // See RFC 2616, section 14.10. - if c := outreq.Header.Get("Connection"); c != "" { - for _, f := range strings.Split(c, ",") { - if f = strings.TrimSpace(f); f != "" { - outreq.Header.Del(f) - } - } - } + removeConnectionHeaders(outreq.Header) // Remove hop-by-hop headers to the backend. Especially // important is "Connection" because we want a persistent @@ -199,32 +191,30 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } res, err := transport.RoundTrip(outreq) - if err != nil { - p.logf("http: proxy error: %v", err) - rw.WriteHeader(http.StatusBadGateway) - return - } - - // Remove hop-by-hop headers listed in the - // "Connection" header of the response. - if c := res.Header.Get("Connection"); c != "" { - for _, f := range strings.Split(c, ",") { - if f = strings.TrimSpace(f); f != "" { - res.Header.Del(f) - } + if res == nil { + res = &http.Response{ + StatusCode: http.StatusBadGateway, + Body: http.NoBody, } } + removeConnectionHeaders(res.Header) + for _, h := range hopHeaders { res.Header.Del(h) } if p.ModifyResponse != nil { - if err := p.ModifyResponse(res); err != nil { + if err != nil { p.logf("http: proxy error: %v", err) - rw.WriteHeader(http.StatusBadGateway) - return } + err = p.ModifyResponse(res) + } + if err != nil { + p.logf("http: proxy error: %v", err) + rw.WriteHeader(http.StatusBadGateway) + res.Body.Close() + return } copyHeader(rw.Header(), res.Header) @@ -265,6 +255,18 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } } +// removeConnectionHeaders removes hop-by-hop headers listed in the "Connection" header of h. +// See RFC 2616, section 14.10. +func removeConnectionHeaders(h http.Header) { + if c := h.Get("Connection"); c != "" { + for _, f := range strings.Split(c, ",") { + if f = strings.TrimSpace(f); f != "" { + h.Del(f) + } + } + } +} + func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) { if p.FlushInterval != 0 { if wf, ok := dst.(writeFlusher); ok { diff --git a/libgo/go/net/http/httputil/reverseproxy_test.go b/libgo/go/net/http/httputil/reverseproxy_test.go index 37a9992375d..822828e5c0d 100644 --- a/libgo/go/net/http/httputil/reverseproxy_test.go +++ b/libgo/go/net/http/httputil/reverseproxy_test.go @@ -631,6 +631,35 @@ func TestReverseProxyModifyResponse(t *testing.T) { } } +// Issue 21255. Test ModifyResponse when an error from transport.RoundTrip +// occurs, and that the proxy returns StatusOK. +func TestReverseProxyModifyResponse_OnError(t *testing.T) { + // Always returns an error + errBackend := httptest.NewUnstartedServer(nil) + errBackend.Config.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests + defer errBackend.Close() + + rpURL, _ := url.Parse(errBackend.URL) + rproxy := NewSingleHostReverseProxy(rpURL) + rproxy.ModifyResponse = func(resp *http.Response) error { + // Will be set for a non-nil error + resp.StatusCode = http.StatusOK + return nil + } + + frontend := httptest.NewServer(rproxy) + defer frontend.Close() + + resp, err := http.Get(frontend.URL) + if err != nil { + t.Fatalf("failed to reach proxy: %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("err != nil: got res.StatusCode %d; expected %d", resp.StatusCode, http.StatusOK) + } + resp.Body.Close() +} + // Issue 16659: log errors from short read func TestReverseProxy_CopyBuffer(t *testing.T) { backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -769,3 +798,47 @@ type roundTripperFunc func(req *http.Request) (*http.Response, error) func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { return fn(req) } + +func TestModifyResponseClosesBody(t *testing.T) { + req, _ := http.NewRequest("GET", "http://foo.tld/", nil) + req.RemoteAddr = "1.2.3.4:56789" + closeCheck := new(checkCloser) + logBuf := new(bytes.Buffer) + outErr := errors.New("ModifyResponse error") + rp := &ReverseProxy{ + Director: func(req *http.Request) {}, + Transport: &staticTransport{&http.Response{ + StatusCode: 200, + Body: closeCheck, + }}, + ErrorLog: log.New(logBuf, "", 0), + ModifyResponse: func(*http.Response) error { + return outErr + }, + } + rec := httptest.NewRecorder() + rp.ServeHTTP(rec, req) + res := rec.Result() + if g, e := res.StatusCode, http.StatusBadGateway; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + if !closeCheck.closed { + t.Errorf("body should have been closed") + } + if g, e := logBuf.String(), outErr.Error(); !strings.Contains(g, e) { + t.Errorf("ErrorLog %q does not contain %q", g, e) + } +} + +type checkCloser struct { + closed bool +} + +func (cc *checkCloser) Close() error { + cc.closed = true + return nil +} + +func (cc *checkCloser) Read(b []byte) (int, error) { + return len(b), nil +} diff --git a/libgo/go/net/http/pprof/pprof.go b/libgo/go/net/http/pprof/pprof.go index 12c7599ab0f..21992d62da2 100644 --- a/libgo/go/net/http/pprof/pprof.go +++ b/libgo/go/net/http/pprof/pprof.go @@ -69,11 +69,11 @@ import ( ) func init() { - http.Handle("/debug/pprof/", http.HandlerFunc(Index)) - http.Handle("/debug/pprof/cmdline", http.HandlerFunc(Cmdline)) - http.Handle("/debug/pprof/profile", http.HandlerFunc(Profile)) - http.Handle("/debug/pprof/symbol", http.HandlerFunc(Symbol)) - http.Handle("/debug/pprof/trace", http.HandlerFunc(Trace)) + http.HandleFunc("/debug/pprof/", Index) + http.HandleFunc("/debug/pprof/cmdline", Cmdline) + http.HandleFunc("/debug/pprof/profile", Profile) + http.HandleFunc("/debug/pprof/symbol", Symbol) + http.HandleFunc("/debug/pprof/trace", Trace) } // Cmdline responds with the running program's diff --git a/libgo/go/net/http/readrequest_test.go b/libgo/go/net/http/readrequest_test.go index 28a148b9acb..22a9c2ef4b2 100644 --- a/libgo/go/net/http/readrequest_test.go +++ b/libgo/go/net/http/readrequest_test.go @@ -454,6 +454,14 @@ abc`)}, {"smuggle_content_len_head", reqBytes(`HEAD / HTTP/1.1 Host: foo Content-Length: 5`)}, + + // golang.org/issue/22464 + {"leading_space_in_header", reqBytes(`HEAD / HTTP/1.1 + Host: foo +Content-Length: 5`)}, + {"leading_tab_in_header", reqBytes(`HEAD / HTTP/1.1 +\tHost: foo +Content-Length: 5`)}, } func TestReadRequest_Bad(t *testing.T) { diff --git a/libgo/go/net/http/request.go b/libgo/go/net/http/request.go index 13f367c1a8f..870af85e04a 100644 --- a/libgo/go/net/http/request.go +++ b/libgo/go/net/http/request.go @@ -490,8 +490,8 @@ var errMissingHost = errors.New("http: Request.Write on Request with no Host or // extraHeaders may be nil // waitForContinue may be nil -func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, waitForContinue func() bool) (err error) { - trace := httptrace.ContextClientTrace(req.Context()) +func (r *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, waitForContinue func() bool) (err error) { + trace := httptrace.ContextClientTrace(r.Context()) if trace != nil && trace.WroteRequest != nil { defer func() { trace.WroteRequest(httptrace.WroteRequestInfo{ @@ -504,12 +504,12 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, wai // is not given, use the host from the request URL. // // Clean the host, in case it arrives with unexpected stuff in it. - host := cleanHost(req.Host) + host := cleanHost(r.Host) if host == "" { - if req.URL == nil { + if r.URL == nil { return errMissingHost } - host = cleanHost(req.URL.Host) + host = cleanHost(r.URL.Host) } // According to RFC 6874, an HTTP client, proxy, or other @@ -517,10 +517,10 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, wai // to an outgoing URI. host = removeZone(host) - ruri := req.URL.RequestURI() - if usingProxy && req.URL.Scheme != "" && req.URL.Opaque == "" { - ruri = req.URL.Scheme + "://" + host + ruri - } else if req.Method == "CONNECT" && req.URL.Path == "" { + ruri := r.URL.RequestURI() + if usingProxy && r.URL.Scheme != "" && r.URL.Opaque == "" { + ruri = r.URL.Scheme + "://" + host + ruri + } else if r.Method == "CONNECT" && r.URL.Path == "" { // CONNECT requests normally give just the host and port, not a full URL. ruri = host } @@ -536,7 +536,7 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, wai w = bw } - _, err = fmt.Fprintf(w, "%s %s HTTP/1.1\r\n", valueOrDefault(req.Method, "GET"), ruri) + _, err = fmt.Fprintf(w, "%s %s HTTP/1.1\r\n", valueOrDefault(r.Method, "GET"), ruri) if err != nil { return err } @@ -550,8 +550,8 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, wai // Use the defaultUserAgent unless the Header contains one, which // may be blank to not send the header. userAgent := defaultUserAgent - if _, ok := req.Header["User-Agent"]; ok { - userAgent = req.Header.Get("User-Agent") + if _, ok := r.Header["User-Agent"]; ok { + userAgent = r.Header.Get("User-Agent") } if userAgent != "" { _, err = fmt.Fprintf(w, "User-Agent: %s\r\n", userAgent) @@ -561,7 +561,7 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, wai } // Process Body,ContentLength,Close,Trailer - tw, err := newTransferWriter(req) + tw, err := newTransferWriter(r) if err != nil { return err } @@ -570,7 +570,7 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, wai return err } - err = req.Header.WriteSubset(w, reqWriteExcludeHeader) + err = r.Header.WriteSubset(w, reqWriteExcludeHeader) if err != nil { return err } @@ -603,7 +603,7 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, wai trace.Wait100Continue() } if !waitForContinue() { - req.closeBody() + r.closeBody() return nil } } diff --git a/libgo/go/net/http/response.go b/libgo/go/net/http/response.go index 0357b605023..4c614bfab0b 100644 --- a/libgo/go/net/http/response.go +++ b/libgo/go/net/http/response.go @@ -27,6 +27,9 @@ var respExcludeHeader = map[string]bool{ // Response represents the response from an HTTP request. // +// The Client and Transport return Responses from servers once +// the response headers have been received. The response body +// is streamed on demand as the Body field is read. type Response struct { Status string // e.g. "200 OK" StatusCode int // e.g. 200 @@ -47,6 +50,10 @@ type Response struct { // Body represents the response body. // + // The response body is streamed on demand as the Body field + // is read. If the network connection fails or the server + // terminates the response, Body.Read calls return an error. + // // The http Client and Transport guarantee that Body is always // non-nil, even on responses without a body or responses with // a zero-length body. It is the caller's responsibility to diff --git a/libgo/go/net/http/response_test.go b/libgo/go/net/http/response_test.go index f1a50bd5989..1ea19619fee 100644 --- a/libgo/go/net/http/response_test.go +++ b/libgo/go/net/http/response_test.go @@ -816,7 +816,6 @@ func TestReadResponseErrors(t *testing.T) { type testCase struct { name string // optional, defaults to in in string - header Header wantErr interface{} // nil, err value, or string substring } @@ -842,22 +841,21 @@ func TestReadResponseErrors(t *testing.T) { } } - contentLength := func(status, body string, wantErr interface{}, header Header) testCase { + contentLength := func(status, body string, wantErr interface{}) testCase { return testCase{ name: fmt.Sprintf("status %q %q", status, body), in: fmt.Sprintf("HTTP/1.1 %s\r\n%s", status, body), wantErr: wantErr, - header: header, } } errMultiCL := "message cannot contain multiple Content-Length headers" tests := []testCase{ - {"", "", nil, io.ErrUnexpectedEOF}, - {"", "HTTP/1.1 301 Moved Permanently\r\nFoo: bar", nil, io.ErrUnexpectedEOF}, - {"", "HTTP/1.1", nil, "malformed HTTP response"}, - {"", "HTTP/2.0", nil, "malformed HTTP response"}, + {"", "", io.ErrUnexpectedEOF}, + {"", "HTTP/1.1 301 Moved Permanently\r\nFoo: bar", io.ErrUnexpectedEOF}, + {"", "HTTP/1.1", "malformed HTTP response"}, + {"", "HTTP/2.0", "malformed HTTP response"}, status("20X Unknown", true), status("abcd Unknown", true), status("二百/两百 OK", true), @@ -883,18 +881,22 @@ func TestReadResponseErrors(t *testing.T) { version("HTTP/1", true), version("http/1.1", true), - contentLength("200 OK", "Content-Length: 10\r\nContent-Length: 7\r\n\r\nGopher hey\r\n", errMultiCL, nil), - contentLength("200 OK", "Content-Length: 7\r\nContent-Length: 7\r\n\r\nGophers\r\n", nil, Header{"Content-Length": {"7"}}), - contentLength("201 OK", "Content-Length: 0\r\nContent-Length: 7\r\n\r\nGophers\r\n", errMultiCL, nil), - contentLength("300 OK", "Content-Length: 0\r\nContent-Length: 0 \r\n\r\nGophers\r\n", nil, Header{"Content-Length": {"0"}}), - contentLength("200 OK", "Content-Length:\r\nContent-Length:\r\n\r\nGophers\r\n", nil, nil), - contentLength("206 OK", "Content-Length:\r\nContent-Length: 0 \r\nConnection: close\r\n\r\nGophers\r\n", errMultiCL, nil), + contentLength("200 OK", "Content-Length: 10\r\nContent-Length: 7\r\n\r\nGopher hey\r\n", errMultiCL), + contentLength("200 OK", "Content-Length: 7\r\nContent-Length: 7\r\n\r\nGophers\r\n", nil), + contentLength("201 OK", "Content-Length: 0\r\nContent-Length: 7\r\n\r\nGophers\r\n", errMultiCL), + contentLength("300 OK", "Content-Length: 0\r\nContent-Length: 0 \r\n\r\nGophers\r\n", nil), + contentLength("200 OK", "Content-Length:\r\nContent-Length:\r\n\r\nGophers\r\n", nil), + contentLength("206 OK", "Content-Length:\r\nContent-Length: 0 \r\nConnection: close\r\n\r\nGophers\r\n", errMultiCL), // multiple content-length headers for 204 and 304 should still be checked - contentLength("204 OK", "Content-Length: 7\r\nContent-Length: 8\r\n\r\n", errMultiCL, nil), - contentLength("204 OK", "Content-Length: 3\r\nContent-Length: 3\r\n\r\n", nil, nil), - contentLength("304 OK", "Content-Length: 880\r\nContent-Length: 1\r\n\r\n", errMultiCL, nil), - contentLength("304 OK", "Content-Length: 961\r\nContent-Length: 961\r\n\r\n", nil, nil), + contentLength("204 OK", "Content-Length: 7\r\nContent-Length: 8\r\n\r\n", errMultiCL), + contentLength("204 OK", "Content-Length: 3\r\nContent-Length: 3\r\n\r\n", nil), + contentLength("304 OK", "Content-Length: 880\r\nContent-Length: 1\r\n\r\n", errMultiCL), + contentLength("304 OK", "Content-Length: 961\r\nContent-Length: 961\r\n\r\n", nil), + + // golang.org/issue/22464 + {"leading space in header", "HTTP/1.1 200 OK\r\n Content-type: text/html\r\nFoo: bar\r\n\r\n", "malformed MIME"}, + {"leading tab in header", "HTTP/1.1 200 OK\r\n\tContent-type: text/html\r\nFoo: bar\r\n\r\n", "malformed MIME"}, } for i, tt := range tests { diff --git a/libgo/go/net/http/serve_test.go b/libgo/go/net/http/serve_test.go index 7137599c42e..1ffa4115009 100644 --- a/libgo/go/net/http/serve_test.go +++ b/libgo/go/net/http/serve_test.go @@ -461,6 +461,68 @@ func TestMuxRedirectLeadingSlashes(t *testing.T) { } } +// Test that the special cased "/route" redirect +// implicitly created by a registered "/route/" +// properly sets the query string in the redirect URL. +// See Issue 17841. +func TestServeWithSlashRedirectKeepsQueryString(t *testing.T) { + setParallel(t) + defer afterTest(t) + + writeBackQuery := func(w ResponseWriter, r *Request) { + fmt.Fprintf(w, "%s", r.URL.RawQuery) + } + + mux := NewServeMux() + mux.HandleFunc("/testOne", writeBackQuery) + mux.HandleFunc("/testTwo/", writeBackQuery) + mux.HandleFunc("/testThree", writeBackQuery) + mux.HandleFunc("/testThree/", func(w ResponseWriter, r *Request) { + fmt.Fprintf(w, "%s:bar", r.URL.RawQuery) + }) + + ts := httptest.NewServer(mux) + defer ts.Close() + + tests := [...]struct { + path string + method string + want string + statusOk bool + }{ + 0: {"/testOne?this=that", "GET", "this=that", true}, + 1: {"/testTwo?foo=bar", "GET", "foo=bar", true}, + 2: {"/testTwo?a=1&b=2&a=3", "GET", "a=1&b=2&a=3", true}, + 3: {"/testTwo?", "GET", "", true}, + 4: {"/testThree?foo", "GET", "foo", true}, + 5: {"/testThree/?foo", "GET", "foo:bar", true}, + 6: {"/testThree?foo", "CONNECT", "foo", true}, + 7: {"/testThree/?foo", "CONNECT", "foo:bar", true}, + + // canonicalization or not + 8: {"/testOne/foo/..?foo", "GET", "foo", true}, + 9: {"/testOne/foo/..?foo", "CONNECT", "404 page not found\n", false}, + } + + for i, tt := range tests { + req, _ := NewRequest(tt.method, ts.URL+tt.path, nil) + res, err := ts.Client().Do(req) + if err != nil { + continue + } + slurp, _ := ioutil.ReadAll(res.Body) + res.Body.Close() + if !tt.statusOk { + if got, want := res.StatusCode, 404; got != want { + t.Errorf("#%d: Status = %d; want = %d", i, got, want) + } + } + if got, want := string(slurp), tt.want; got != want { + t.Errorf("#%d: Body = %q; want = %q", i, got, want) + } + } +} + func BenchmarkServeMux(b *testing.B) { type test struct { @@ -624,12 +686,8 @@ func TestHTTP2WriteDeadlineExtendedOnNewRequest(t *testing.T) { req = req.WithContext(ctx) r, err := c.Do(req) - select { - case <-ctx.Done(): - if ctx.Err() == context.DeadlineExceeded { - t.Fatalf("http2 Get #%d response timed out", i) - } - default: + if ctx.Err() == context.DeadlineExceeded { + t.Fatalf("http2 Get #%d response timed out", i) } if err != nil { t.Fatalf("http2 Get #%d: %v", i, err) @@ -2376,6 +2434,14 @@ func TestTimeoutHandlerEmptyResponse(t *testing.T) { } } +// https://golang.org/issues/22084 +func TestTimeoutHandlerPanicRecovery(t *testing.T) { + wrapper := func(h Handler) Handler { + return TimeoutHandler(h, time.Second, "") + } + testHandlerPanic(t, false, false, wrapper, "intentional death for testing") +} + func TestRedirectBadPath(t *testing.T) { // This used to crash. It's not valid input (bad path), but it // shouldn't crash. @@ -2436,6 +2502,37 @@ func TestRedirect(t *testing.T) { } } +// Test that Content-Type header is set for GET and HEAD requests. +func TestRedirectContentTypeAndBody(t *testing.T) { + var tests = []struct { + method string + wantCT string + wantBody string + }{ + {MethodGet, "text/html; charset=utf-8", "<a href=\"/foo\">Found</a>.\n\n"}, + {MethodHead, "text/html; charset=utf-8", ""}, + {MethodPost, "", ""}, + {MethodDelete, "", ""}, + {"foo", "", ""}, + } + for _, tt := range tests { + req := httptest.NewRequest(tt.method, "http://example.com/qux/", nil) + rec := httptest.NewRecorder() + Redirect(rec, req, "/foo", 302) + if got, want := rec.Header().Get("Content-Type"), tt.wantCT; got != want { + t.Errorf("Redirect(%q) generated Content-Type header %q; want %q", tt.method, got, want) + } + resp := rec.Result() + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + if got, want := string(body), tt.wantBody; got != want { + t.Errorf("Redirect(%q) generated Body %q; want %q", tt.method, got, want) + } + } +} + // TestZeroLengthPostAndResponse exercises an optimization done by the Transport: // when there is no body (either because the method doesn't permit a body, or an // explicit Content-Length of zero is present), then the transport can re-use the @@ -2489,22 +2586,22 @@ func testZeroLengthPostAndResponse(t *testing.T, h2 bool) { } } -func TestHandlerPanicNil_h1(t *testing.T) { testHandlerPanic(t, false, h1Mode, nil) } -func TestHandlerPanicNil_h2(t *testing.T) { testHandlerPanic(t, false, h2Mode, nil) } +func TestHandlerPanicNil_h1(t *testing.T) { testHandlerPanic(t, false, h1Mode, nil, nil) } +func TestHandlerPanicNil_h2(t *testing.T) { testHandlerPanic(t, false, h2Mode, nil, nil) } func TestHandlerPanic_h1(t *testing.T) { - testHandlerPanic(t, false, h1Mode, "intentional death for testing") + testHandlerPanic(t, false, h1Mode, nil, "intentional death for testing") } func TestHandlerPanic_h2(t *testing.T) { - testHandlerPanic(t, false, h2Mode, "intentional death for testing") + testHandlerPanic(t, false, h2Mode, nil, "intentional death for testing") } func TestHandlerPanicWithHijack(t *testing.T) { // Only testing HTTP/1, and our http2 server doesn't support hijacking. - testHandlerPanic(t, true, h1Mode, "intentional death for testing") + testHandlerPanic(t, true, h1Mode, nil, "intentional death for testing") } -func testHandlerPanic(t *testing.T, withHijack, h2 bool, panicValue interface{}) { +func testHandlerPanic(t *testing.T, withHijack, h2 bool, wrapper func(Handler) Handler, panicValue interface{}) { defer afterTest(t) // Unlike the other tests that set the log output to ioutil.Discard // to quiet the output, this test uses a pipe. The pipe serves three @@ -2527,7 +2624,7 @@ func testHandlerPanic(t *testing.T, withHijack, h2 bool, panicValue interface{}) defer log.SetOutput(os.Stderr) defer pw.Close() - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + var handler Handler = HandlerFunc(func(w ResponseWriter, r *Request) { if withHijack { rwc, _, err := w.(Hijacker).Hijack() if err != nil { @@ -2536,7 +2633,11 @@ func testHandlerPanic(t *testing.T, withHijack, h2 bool, panicValue interface{}) defer rwc.Close() } panic(panicValue) - })) + }) + if wrapper != nil { + handler = wrapper(handler) + } + cst := newClientServerTest(t, h2, handler) defer cst.close() // Do a blocking read on the log output pipe so its logging @@ -2691,15 +2792,28 @@ func testRequestLimit(t *testing.T, h2 bool) { req.Header.Set(fmt.Sprintf("header%05d", i), fmt.Sprintf("val%05d", i)) } res, err := cst.c.Do(req) - if err != nil { + if res != nil { + defer res.Body.Close() + } + if h2 { + // In HTTP/2, the result depends on a race. If the client has received the + // server's SETTINGS before RoundTrip starts sending the request, then RoundTrip + // will fail with an error. Otherwise, the client should receive a 431 from the + // server. + if err == nil && res.StatusCode != 431 { + t.Fatalf("expected 431 response status; got: %d %s", res.StatusCode, res.Status) + } + } else { + // In HTTP/1, we expect a 431 from the server. // Some HTTP clients may fail on this undefined behavior (server replying and // closing the connection while the request is still being written), but // we do support it (at least currently), so we expect a response below. - t.Fatalf("Do: %v", err) - } - defer res.Body.Close() - if res.StatusCode != 431 { - t.Fatalf("expected 431 response status; got: %d %s", res.StatusCode, res.Status) + if err != nil { + t.Fatalf("Do: %v", err) + } + if res.StatusCode != 431 { + t.Fatalf("expected 431 response status; got: %d %s", res.StatusCode, res.Status) + } } } @@ -3325,9 +3439,6 @@ func TestHeaderToWire(t *testing.T) { handler: func(rw ResponseWriter, r *Request) { }, check: func(got string) error { - if !strings.Contains(got, "Content-Type: text/plain") { - return errors.New("wrong content-type; want text/plain") - } if !strings.Contains(got, "Content-Length: 0") { return errors.New("want 0 content-length") } @@ -5336,7 +5447,11 @@ func TestServerCloseDeadlock(t *testing.T) { func TestServerKeepAlivesEnabled_h1(t *testing.T) { testServerKeepAlivesEnabled(t, h1Mode) } func TestServerKeepAlivesEnabled_h2(t *testing.T) { testServerKeepAlivesEnabled(t, h2Mode) } func testServerKeepAlivesEnabled(t *testing.T, h2 bool) { - setParallel(t) + if h2 { + restore := ExportSetH2GoawayTimeout(10 * time.Millisecond) + defer restore() + } + // Not parallel: messes with global variable. (http2goAwayTimeout) defer afterTest(t) cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "%v", r.RemoteAddr) diff --git a/libgo/go/net/http/server.go b/libgo/go/net/http/server.go index 2fa8ab23d8a..3fa66601649 100644 --- a/libgo/go/net/http/server.go +++ b/libgo/go/net/http/server.go @@ -179,7 +179,7 @@ type Hijacker interface { // The returned bufio.Reader may contain unprocessed buffered // data from the client. // - // After a call to Hijack, the original Request.Body should + // After a call to Hijack, the original Request.Body must // not be used. Hijack() (net.Conn, *bufio.ReadWriter, error) } @@ -729,6 +729,9 @@ func (cr *connReader) Read(p []byte) (n int, err error) { cr.lock() if cr.inRead { cr.unlock() + if cr.conn.hijacked() { + panic("invalid Body.Read call. After hijacked, the original Request must not be used") + } panic("invalid concurrent Body.Read call") } if cr.hitReadLimit() { @@ -1043,7 +1046,25 @@ func (w *response) Header() Header { // well read them) const maxPostHandlerReadBytes = 256 << 10 +func checkWriteHeaderCode(code int) { + // Issue 22880: require valid WriteHeader status codes. + // For now we only enforce that it's three digits. + // In the future we might block things over 599 (600 and above aren't defined + // at http://httpwg.org/specs/rfc7231.html#status.codes) + // and we might block under 200 (once we have more mature 1xx support). + // But for now any three digits. + // + // We used to send "HTTP/1.1 000 0" on the wire in responses but there's + // no equivalent bogus thing we can realistically send in HTTP/2, + // so we'll consistently panic instead and help people find their bugs + // early. (We can't return an error from WriteHeader even if we wanted to.) + if code < 100 || code > 999 { + panic(fmt.Sprintf("invalid WriteHeader code %v", code)) + } +} + func (w *response) WriteHeader(code int) { + checkWriteHeaderCode(code) if w.conn.hijacked() { w.conn.server.logf("http: response.WriteHeader on hijacked connection") return @@ -1210,7 +1231,7 @@ func (cw *chunkWriter) writeHeader(p []byte) { } } - // Check for a explicit (and valid) Content-Length header. + // Check for an explicit (and valid) Content-Length header. hasCL := w.contentLength != -1 if w.wants10KeepAlive && (isHEAD || hasCL || !bodyAllowedForStatus(w.status)) { @@ -1308,7 +1329,7 @@ func (cw *chunkWriter) writeHeader(p []byte) { if bodyAllowedForStatus(code) { // If no content type, apply sniffing algorithm to body. _, haveType := header["Content-Type"] - if !haveType && !hasTE { + if !haveType && !hasTE && len(p) > 0 { setHeader.contentType = DetectContentType(p) } } else { @@ -1337,7 +1358,7 @@ func (cw *chunkWriter) writeHeader(p []byte) { } else if hasCL { delHeader("Transfer-Encoding") } else if w.req.ProtoAtLeast(1, 1) { - // HTTP/1.1 or greater: Transfer-Encoding has been set to identity, and no + // HTTP/1.1 or greater: Transfer-Encoding has been set to identity, and no // content-length has been provided. The connection must be closed after the // reply is written, and no chunking is to be done. This is the setup // recommended in the Server-Sent Events candidate recommendation 11, @@ -2014,6 +2035,9 @@ func Redirect(w ResponseWriter, r *Request, url string, code int) { } w.Header().Set("Location", hexEscapeNonASCII(url)) + if r.Method == "GET" || r.Method == "HEAD" { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + } w.WriteHeader(code) // RFC 2616 recommends that a short note "SHOULD" be included in the @@ -2105,9 +2129,8 @@ type ServeMux struct { } type muxEntry struct { - explicit bool - h Handler - pattern string + h Handler + pattern string } // NewServeMux allocates and returns a new ServeMux. @@ -2185,6 +2208,31 @@ func (mux *ServeMux) match(path string) (h Handler, pattern string) { return } +// redirectToPathSlash determines if the given path needs appending "/" to it. +// This occurs when a handler for path + "/" was already registered, but +// not for path itself. If the path needs appending to, it creates a new +// URL, setting the path to u.Path + "/" and returning true to indicate so. +func (mux *ServeMux) redirectToPathSlash(path string, u *url.URL) (*url.URL, bool) { + if !mux.shouldRedirect(path) { + return u, false + } + path = path + "/" + u = &url.URL{Path: path, RawQuery: u.RawQuery} + return u, true +} + +// shouldRedirect reports whether the given path should be redirected to +// path+"/". This should happen if a handler is registered for path+"/" but +// not path -- see comments at ServeMux. +func (mux *ServeMux) shouldRedirect(path string) bool { + if _, exist := mux.m[path]; exist { + return false + } + n := len(path) + _, exist := mux.m[path+"/"] + return n > 0 && path[n-1] != '/' && exist +} + // Handler returns the handler to use for the given request, // consulting r.Method, r.Host, and r.URL.Path. It always returns // a non-nil handler. If the path is not in its canonical form, the @@ -2204,6 +2252,13 @@ func (mux *ServeMux) Handler(r *Request) (h Handler, pattern string) { // CONNECT requests are not canonicalized. if r.Method == "CONNECT" { + // If r.URL.Path is /tree and its handler is not registered, + // the /tree -> /tree/ redirect applies to CONNECT requests + // but the path canonicalization does not. + if u, ok := mux.redirectToPathSlash(r.URL.Path, r.URL); ok { + return RedirectHandler(u.String(), StatusMovedPermanently), u.Path + } + return mux.handler(r.Host, r.URL.Path) } @@ -2211,6 +2266,13 @@ func (mux *ServeMux) Handler(r *Request) (h Handler, pattern string) { // before passing to mux.handler. host := stripHostPort(r.Host) path := cleanPath(r.URL.Path) + + // If the given path is /tree and its handler is not registered, + // redirect for /tree/. + if u, ok := mux.redirectToPathSlash(path, r.URL); ok { + return RedirectHandler(u.String(), StatusMovedPermanently), u.Path + } + if path != r.URL.Path { _, pattern = mux.handler(host, path) url := *r.URL @@ -2261,40 +2323,23 @@ func (mux *ServeMux) Handle(pattern string, handler Handler) { defer mux.mu.Unlock() if pattern == "" { - panic("http: invalid pattern " + pattern) + panic("http: invalid pattern") } if handler == nil { panic("http: nil handler") } - if mux.m[pattern].explicit { + if _, exist := mux.m[pattern]; exist { panic("http: multiple registrations for " + pattern) } if mux.m == nil { mux.m = make(map[string]muxEntry) } - mux.m[pattern] = muxEntry{explicit: true, h: handler, pattern: pattern} + mux.m[pattern] = muxEntry{h: handler, pattern: pattern} if pattern[0] != '/' { mux.hosts = true } - - // Helpful behavior: - // If pattern is /tree/, insert an implicit permanent redirect for /tree. - // It can be overridden by an explicit registration. - n := len(pattern) - if n > 0 && pattern[n-1] == '/' && !mux.m[pattern[0:n-1]].explicit { - // If pattern contains a host name, strip it and use remaining - // path for redirect. - path := pattern - if pattern[0] != '/' { - // In pattern, at least the last character is a '/', so - // strings.Index can't be -1. - path = pattern[strings.Index(pattern, "/"):] - } - url := &url.URL{Path: path} - mux.m[pattern[0:n-1]] = muxEntry{h: RedirectHandler(url.String(), StatusMovedPermanently), pattern: pattern} - } } // HandleFunc registers the handler function for the given pattern. @@ -2323,7 +2368,7 @@ func Serve(l net.Listener, handler Handler) error { return srv.Serve(l) } -// Serve accepts incoming HTTPS connections on the listener l, +// ServeTLS accepts incoming HTTPS connections on the listener l, // creating a new service goroutine for each. The service goroutines // read requests and then call handler to reply to them. // @@ -2396,9 +2441,9 @@ type Server struct { ConnState func(net.Conn, ConnState) // ErrorLog specifies an optional logger for errors accepting - // connections and unexpected behavior from handlers. - // If nil, logging goes to os.Stderr via the log package's - // standard logger. + // connections, unexpected behavior from handlers, and + // underlying FileSystem errors. + // If nil, logging is done via the log package's standard logger. ErrorLog *log.Logger disableKeepAlives int32 // accessed atomically. @@ -2483,7 +2528,8 @@ var shutdownPollInterval = 500 * time.Millisecond // Shutdown does not attempt to close nor wait for hijacked // connections such as WebSockets. The caller of Shutdown should // separately notify such long-lived connections of shutdown and wait -// for them to close, if desired. +// for them to close, if desired. See RegisterOnShutdown for a way to +// register shutdown notification functions. func (srv *Server) Shutdown(ctx context.Context) error { atomic.AddInt32(&srv.inShutdown, 1) defer atomic.AddInt32(&srv.inShutdown, -1) @@ -2732,7 +2778,7 @@ func (srv *Server) Serve(l net.Listener) error { // server's certificate, any intermediates, and the CA's certificate. // // For HTTP/2 support, srv.TLSConfig should be initialized to the -// provided listener's TLS Config before calling Serve. If +// provided listener's TLS Config before calling ServeTLS. If // srv.TLSConfig is non-nil and doesn't include the string "h2" in // Config.NextProtos, HTTP/2 support is not enabled. // @@ -2849,6 +2895,18 @@ func (s *Server) logf(format string, args ...interface{}) { } } +// logf prints to the ErrorLog of the *Server associated with request r +// via ServerContextKey. If there's no associated server, or if ErrorLog +// is nil, logging is done via the log package's standard logger. +func logf(r *Request, format string, args ...interface{}) { + s, _ := r.Context().Value(ServerContextKey).(*Server) + if s != nil && s.ErrorLog != nil { + s.ErrorLog.Printf(format, args...) + } else { + log.Printf(format, args...) + } +} + // ListenAndServe listens on the TCP network address addr // and then calls Serve with handler to handle requests // on incoming connections. @@ -2940,6 +2998,8 @@ func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error { return err } + defer ln.Close() + return srv.ServeTLS(tcpKeepAliveListener{ln.(*net.TCPListener)}, certFile, keyFile) } @@ -3015,9 +3075,9 @@ type timeoutHandler struct { body string dt time.Duration - // When set, no timer will be created and this channel will + // When set, no context will be created and this context will // be used instead. - testTimeout <-chan time.Time + testContext context.Context } func (h *timeoutHandler) errorBody() string { @@ -3028,22 +3088,31 @@ func (h *timeoutHandler) errorBody() string { } func (h *timeoutHandler) ServeHTTP(w ResponseWriter, r *Request) { - var t *time.Timer - timeout := h.testTimeout - if timeout == nil { - t = time.NewTimer(h.dt) - timeout = t.C + ctx := h.testContext + if ctx == nil { + var cancelCtx context.CancelFunc + ctx, cancelCtx = context.WithTimeout(r.Context(), h.dt) + defer cancelCtx() } + r = r.WithContext(ctx) done := make(chan struct{}) tw := &timeoutWriter{ w: w, h: make(Header), } + panicChan := make(chan interface{}, 1) go func() { + defer func() { + if p := recover(); p != nil { + panicChan <- p + } + }() h.handler.ServeHTTP(tw, r) close(done) }() select { + case p := <-panicChan: + panic(p) case <-done: tw.mu.Lock() defer tw.mu.Unlock() @@ -3056,10 +3125,7 @@ func (h *timeoutHandler) ServeHTTP(w ResponseWriter, r *Request) { } w.WriteHeader(tw.code) w.Write(tw.wbuf.Bytes()) - if t != nil { - t.Stop() - } - case <-timeout: + case <-ctx.Done(): tw.mu.Lock() defer tw.mu.Unlock() w.WriteHeader(StatusServiceUnavailable) @@ -3095,6 +3161,7 @@ func (tw *timeoutWriter) Write(p []byte) (int, error) { } func (tw *timeoutWriter) WriteHeader(code int) { + checkWriteHeaderCode(code) tw.mu.Lock() defer tw.mu.Unlock() if tw.timedOut || tw.wroteHeader { @@ -3116,10 +3183,10 @@ type tcpKeepAliveListener struct { *net.TCPListener } -func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { +func (ln tcpKeepAliveListener) Accept() (net.Conn, error) { tc, err := ln.AcceptTCP() if err != nil { - return + return nil, err } tc.SetKeepAlive(true) tc.SetKeepAlivePeriod(3 * time.Minute) diff --git a/libgo/go/net/http/sniff.go b/libgo/go/net/http/sniff.go index ecc65e4de64..365a36c79ef 100644 --- a/libgo/go/net/http/sniff.go +++ b/libgo/go/net/http/sniff.go @@ -91,6 +91,7 @@ var sniffSignatures = []sniffSig{ ct: "image/webp", }, &exactSig{[]byte("\x00\x00\x01\x00"), "image/vnd.microsoft.icon"}, + &maskedSig{ mask: []byte("\xFF\xFF\xFF\xFF\x00\x00\x00\x00\xFF\xFF\xFF\xFF"), pat: []byte("RIFF\x00\x00\x00\x00WAVE"), @@ -126,6 +127,20 @@ var sniffSignatures = []sniffSig{ pat: []byte("RIFF\x00\x00\x00\x00AVI "), ct: "video/avi", }, + + // Fonts + &maskedSig{ + // 34 NULL bytes followed by the string "LP" + pat: []byte("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x4C\x50"), + // 34 NULL bytes followed by \xF\xF + mask: []byte("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xFF\xFF"), + ct: "application/vnd.ms-fontobject", + }, + &exactSig{[]byte("\x00\x01\x00\x00"), "application/font-ttf"}, + &exactSig{[]byte("OTTO"), "application/font-off"}, + &exactSig{[]byte("ttcf"), "application/font-cff"}, + &exactSig{[]byte("wOFF"), "application/font-woff"}, + &exactSig{[]byte("\x1A\x45\xDF\xA3"), "video/webm"}, &exactSig{[]byte("\x52\x61\x72\x20\x1A\x07\x00"), "application/x-rar-compressed"}, &exactSig{[]byte("\x50\x4B\x03\x04"), "application/zip"}, diff --git a/libgo/go/net/http/sniff_test.go b/libgo/go/net/http/sniff_test.go index 24f1298e5d9..bf1f6be41b1 100644 --- a/libgo/go/net/http/sniff_test.go +++ b/libgo/go/net/http/sniff_test.go @@ -55,6 +55,17 @@ var sniffTests = []struct { {"MP4 video", []byte("\x00\x00\x00\x18ftypmp42\x00\x00\x00\x00mp42isom<\x06t\xbfmdat"), "video/mp4"}, {"AVI video #1", []byte("RIFF,O\n\x00AVI LISTÀ"), "video/avi"}, {"AVI video #2", []byte("RIFF,\n\x00\x00AVI LISTÀ"), "video/avi"}, + + // Font types. + // {"MS.FontObject", []byte("\x00\x00")}, + {"TTF sample I", []byte("\x00\x01\x00\x00\x00\x17\x01\x00\x00\x04\x01\x60\x4f"), "application/font-ttf"}, + {"TTF sample II", []byte("\x00\x01\x00\x00\x00\x0e\x00\x80\x00\x03\x00\x60\x46"), "application/font-ttf"}, + + {"OTTO sample I", []byte("\x4f\x54\x54\x4f\x00\x0e\x00\x80\x00\x03\x00\x60\x42\x41\x53\x45"), "application/font-off"}, + + {"woff sample I", []byte("\x77\x4f\x46\x46\x00\x01\x00\x00\x00\x00\x30\x54\x00\x0d\x00\x00"), "application/font-woff"}, + // Woff2 is not yet recognized, change this test once mime-sniff working group adds woff2 + {"woff2 not recognized", []byte("\x77\x4f\x46\x32\x00\x01\x00\x00\x00"), "application/octet-stream"}, } func TestDetectContentType(t *testing.T) { @@ -88,8 +99,17 @@ func testServerContentType(t *testing.T, h2 bool) { t.Errorf("%v: %v", tt.desc, err) continue } - if ct := resp.Header.Get("Content-Type"); ct != tt.contentType { - t.Errorf("%v: Content-Type = %q, want %q", tt.desc, ct, tt.contentType) + // DetectContentType is defined to return + // text/plain; charset=utf-8 for an empty body, + // but as of Go 1.10 the HTTP server has been changed + // to return no content-type at all for an empty body. + // Adjust the expectation here. + wantContentType := tt.contentType + if len(tt.data) == 0 { + wantContentType = "" + } + if ct := resp.Header.Get("Content-Type"); ct != wantContentType { + t.Errorf("%v: Content-Type = %q, want %q", tt.desc, ct, wantContentType) } data, err := ioutil.ReadAll(resp.Body) if err != nil { diff --git a/libgo/go/net/http/transfer.go b/libgo/go/net/http/transfer.go index 8faff2d74a6..a400a6abb1f 100644 --- a/libgo/go/net/http/transfer.go +++ b/libgo/go/net/http/transfer.go @@ -497,7 +497,7 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) { // or close connection when finished, since multipart is not supported yet switch { case chunked(t.TransferEncoding): - if noResponseBodyExpected(t.RequestMethod) { + if noResponseBodyExpected(t.RequestMethod) || !bodyAllowedForStatus(t.StatusCode) { t.Body = NoBody } else { t.Body = &body{src: internal.NewChunkedReader(r), hdr: msg, r: r, closing: t.Close} @@ -663,9 +663,8 @@ func fixLength(isResponse bool, status int, requestMethod string, header Header, return -1, err } return n, nil - } else { - header.Del("Content-Length") } + header.Del("Content-Length") if isRequest { // RFC 2616 neither explicitly permits nor forbids an diff --git a/libgo/go/net/http/transport.go b/libgo/go/net/http/transport.go index 6a89392a996..45e3fd2eba7 100644 --- a/libgo/go/net/http/transport.go +++ b/libgo/go/net/http/transport.go @@ -1016,6 +1016,69 @@ func (d oneConnDialer) Dial(network, addr string) (net.Conn, error) { } } +// The connect method and the transport can both specify a TLS +// Host name. The transport's name takes precedence if present. +func chooseTLSHost(cm connectMethod, t *Transport) string { + tlsHost := "" + if t.TLSClientConfig != nil { + tlsHost = t.TLSClientConfig.ServerName + } + if tlsHost == "" { + tlsHost = cm.tlsHost() + } + return tlsHost +} + +// Add TLS to a persistent connection, i.e. negotiate a TLS session. If pconn is already a TLS +// tunnel, this function establishes a nested TLS session inside the encrypted channel. +// The remote endpoint's name may be overridden by TLSClientConfig.ServerName. +func (pconn *persistConn) addTLS(name string, trace *httptrace.ClientTrace) error { + // Initiate TLS and check remote host name against certificate. + cfg := cloneTLSConfig(pconn.t.TLSClientConfig) + if cfg.ServerName == "" { + cfg.ServerName = name + } + plainConn := pconn.conn + tlsConn := tls.Client(plainConn, cfg) + errc := make(chan error, 2) + var timer *time.Timer // for canceling TLS handshake + if d := pconn.t.TLSHandshakeTimeout; d != 0 { + timer = time.AfterFunc(d, func() { + errc <- tlsHandshakeTimeoutError{} + }) + } + go func() { + if trace != nil && trace.TLSHandshakeStart != nil { + trace.TLSHandshakeStart() + } + err := tlsConn.Handshake() + if timer != nil { + timer.Stop() + } + errc <- err + }() + if err := <-errc; err != nil { + plainConn.Close() + if trace != nil && trace.TLSHandshakeDone != nil { + trace.TLSHandshakeDone(tls.ConnectionState{}, err) + } + return err + } + if !cfg.InsecureSkipVerify { + if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil { + plainConn.Close() + return err + } + } + cs := tlsConn.ConnectionState() + if trace != nil && trace.TLSHandshakeDone != nil { + trace.TLSHandshakeDone(cs, nil) + } + pconn.tlsState = &cs + pconn.conn = tlsConn + return nil +} + func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistConn, error) { pconn := &persistConn{ t: t, @@ -1027,15 +1090,21 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon writeLoopDone: make(chan struct{}), } trace := httptrace.ContextClientTrace(ctx) - tlsDial := t.DialTLS != nil && cm.targetScheme == "https" && cm.proxyURL == nil - if tlsDial { + wrapErr := func(err error) error { + if cm.proxyURL != nil { + // Return a typed error, per Issue 16997 + return &net.OpError{Op: "proxyconnect", Net: "tcp", Err: err} + } + return err + } + if cm.scheme() == "https" && t.DialTLS != nil { var err error pconn.conn, err = t.DialTLS("tcp", cm.addr()) if err != nil { - return nil, err + return nil, wrapErr(err) } if pconn.conn == nil { - return nil, errors.New("net/http: Transport.DialTLS returned (nil, nil)") + return nil, wrapErr(errors.New("net/http: Transport.DialTLS returned (nil, nil)")) } if tc, ok := pconn.conn.(*tls.Conn); ok { // Handshake here, in case DialTLS didn't. TLSNextProto below @@ -1059,13 +1128,18 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon } else { conn, err := t.dial(ctx, "tcp", cm.addr()) if err != nil { - if cm.proxyURL != nil { - // Return a typed error, per Issue 16997: - err = &net.OpError{Op: "proxyconnect", Net: "tcp", Err: err} - } - return nil, err + return nil, wrapErr(err) } pconn.conn = conn + if cm.scheme() == "https" { + var firstTLSHost string + if firstTLSHost, _, err = net.SplitHostPort(cm.addr()); err != nil { + return nil, wrapErr(err) + } + if err = pconn.addTLS(firstTLSHost, trace); err != nil { + return nil, wrapErr(err) + } + } } // Proxy setup. @@ -1125,54 +1199,17 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon if resp.StatusCode != 200 { f := strings.SplitN(resp.Status, " ", 2) conn.Close() + if len(f) < 2 { + return nil, errors.New("unknown status code") + } return nil, errors.New(f[1]) } } - if cm.targetScheme == "https" && !tlsDial { - // Initiate TLS and check remote host name against certificate. - cfg := cloneTLSConfig(t.TLSClientConfig) - if cfg.ServerName == "" { - cfg.ServerName = cm.tlsHost() - } - plainConn := pconn.conn - tlsConn := tls.Client(plainConn, cfg) - errc := make(chan error, 2) - var timer *time.Timer // for canceling TLS handshake - if d := t.TLSHandshakeTimeout; d != 0 { - timer = time.AfterFunc(d, func() { - errc <- tlsHandshakeTimeoutError{} - }) - } - go func() { - if trace != nil && trace.TLSHandshakeStart != nil { - trace.TLSHandshakeStart() - } - err := tlsConn.Handshake() - if timer != nil { - timer.Stop() - } - errc <- err - }() - if err := <-errc; err != nil { - plainConn.Close() - if trace != nil && trace.TLSHandshakeDone != nil { - trace.TLSHandshakeDone(tls.ConnectionState{}, err) - } + if cm.proxyURL != nil && cm.targetScheme == "https" { + if err := pconn.addTLS(cm.tlsHost(), trace); err != nil { return nil, err } - if !cfg.InsecureSkipVerify { - if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil { - plainConn.Close() - return nil, err - } - } - cs := tlsConn.ConnectionState() - if trace != nil && trace.TLSHandshakeDone != nil { - trace.TLSHandshakeDone(cs, nil) - } - pconn.tlsState = &cs - pconn.conn = tlsConn } if s := pconn.tlsState; s != nil && s.NegotiatedProtocolIsMutual && s.NegotiatedProtocol != "" { @@ -1224,8 +1261,8 @@ func useProxy(addr string) bool { } } - no_proxy := noProxyEnv.Get() - if no_proxy == "*" { + noProxy := noProxyEnv.Get() + if noProxy == "*" { return false } @@ -1234,7 +1271,7 @@ func useProxy(addr string) bool { addr = addr[:strings.LastIndex(addr, ":")] } - for _, p := range strings.Split(no_proxy, ",") { + for _, p := range strings.Split(noProxy, ",") { p = strings.ToLower(strings.TrimSpace(p)) if len(p) == 0 { continue @@ -1266,21 +1303,24 @@ func useProxy(addr string) bool { // // A connect method may be of the following types: // -// Cache key form Description -// ----------------- ------------------------- -// |http|foo.com http directly to server, no proxy -// |https|foo.com https directly to server, no proxy -// http://proxy.com|https|foo.com http to proxy, then CONNECT to foo.com -// http://proxy.com|http http to proxy, http to anywhere after that -// socks5://proxy.com|http|foo.com socks5 to proxy, then http to foo.com -// socks5://proxy.com|https|foo.com socks5 to proxy, then https to foo.com -// -// Note: no support to https to the proxy yet. +// Cache key form Description +// ----------------- ------------------------- +// |http|foo.com http directly to server, no proxy +// |https|foo.com https directly to server, no proxy +// http://proxy.com|https|foo.com http to proxy, then CONNECT to foo.com +// http://proxy.com|http http to proxy, http to anywhere after that +// socks5://proxy.com|http|foo.com socks5 to proxy, then http to foo.com +// socks5://proxy.com|https|foo.com socks5 to proxy, then https to foo.com +// https://proxy.com|https|foo.com https to proxy, then CONNECT to foo.com +// https://proxy.com|http https to proxy, http to anywhere after that // type connectMethod struct { proxyURL *url.URL // nil for no proxy, else full proxy URL targetScheme string // "http" or "https" - targetAddr string // Not used if http proxy + http targetScheme (4th example in table) + // If proxyURL specifies an http or https proxy, and targetScheme is http (not https), + // then targetAddr is not included in the connect method key, because the socket can + // be reused for different targetAddr values. + targetAddr string } func (cm *connectMethod) key() connectMethodKey { @@ -1288,7 +1328,7 @@ func (cm *connectMethod) key() connectMethodKey { targetAddr := cm.targetAddr if cm.proxyURL != nil { proxyStr = cm.proxyURL.String() - if strings.HasPrefix(cm.proxyURL.Scheme, "http") && cm.targetScheme == "http" { + if (cm.proxyURL.Scheme == "http" || cm.proxyURL.Scheme == "https") && cm.targetScheme == "http" { targetAddr = "" } } @@ -1299,6 +1339,14 @@ func (cm *connectMethod) key() connectMethodKey { } } +// scheme returns the first hop scheme: http, https, or socks5 +func (cm *connectMethod) scheme() string { + if cm.proxyURL != nil { + return cm.proxyURL.Scheme + } + return cm.targetScheme +} + // addr returns the first hop "host:port" to which we need to TCP connect. func (cm *connectMethod) addr() string { if cm.proxyURL != nil { @@ -1616,6 +1664,7 @@ func (pc *persistConn) readLoop() { body: resp.Body, earlyCloseFn: func() error { waitForBodyRead <- false + <-eofc // will be closed by deferred call at the end of the function return nil }, @@ -2021,8 +2070,8 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err // a t.Logf func. See export_test.go's Request.WithT method. type tLogKey struct{} -func (r *transportRequest) logf(format string, args ...interface{}) { - if logf, ok := r.Request.Context().Value(tLogKey{}).(func(string, ...interface{})); ok { +func (tr *transportRequest) logf(format string, args ...interface{}) { + if logf, ok := tr.Request.Context().Value(tLogKey{}).(func(string, ...interface{})); ok { logf(time.Now().Format(time.RFC3339Nano)+": "+format, args...) } } diff --git a/libgo/go/net/http/transport_test.go b/libgo/go/net/http/transport_test.go index 27b55dca2f3..55880774256 100644 --- a/libgo/go/net/http/transport_test.go +++ b/libgo/go/net/http/transport_test.go @@ -124,6 +124,34 @@ func (tcs *testConnSet) check(t *testing.T) { } } +func TestReuseRequest(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Write([]byte("{}")) + })) + defer ts.Close() + + c := ts.Client() + req, _ := NewRequest("GET", ts.URL, nil) + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + err = res.Body.Close() + if err != nil { + t.Fatal(err) + } + + res, err = c.Do(req) + if err != nil { + t.Fatal(err) + } + err = res.Body.Close() + if err != nil { + t.Fatal(err) + } +} + // Two subsequent requests and verify their response is the same. // The response from the server is our own IP:port func TestTransportKeepAlives(t *testing.T) { @@ -933,14 +961,10 @@ func TestTransportExpect100Continue(t *testing.T) { func TestSocks5Proxy(t *testing.T) { defer afterTest(t) ch := make(chan string, 1) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { - ch <- "real server" - })) - defer ts.Close() l := newLocalListener(t) defer l.Close() - go func() { - defer close(ch) + defer close(ch) + proxy := func(t *testing.T) { s, err := l.Accept() if err != nil { t.Errorf("socks5 proxy Accept(): %v", err) @@ -975,7 +999,8 @@ func TestSocks5Proxy(t *testing.T) { case 4: ipLen = 16 default: - t.Fatalf("socks5 proxy second read: unexpected address type %v", buf[4]) + t.Errorf("socks5 proxy second read: unexpected address type %v", buf[4]) + return } if _, err := io.ReadFull(s, buf[4:ipLen+6]); err != nil { t.Errorf("socks5 proxy address read: %v", err) @@ -988,71 +1013,196 @@ func TestSocks5Proxy(t *testing.T) { t.Errorf("socks5 proxy connect write: %v", err) return } - done := make(chan struct{}) - srv := &Server{Handler: HandlerFunc(func(w ResponseWriter, r *Request) { - done <- struct{}{} - })} - srv.Serve(&oneConnListener{conn: s}) - <-done - srv.Shutdown(context.Background()) ch <- fmt.Sprintf("proxy for %s:%d", ip, port) - }() - pu, err := url.Parse("socks5://" + l.Addr().String()) - if err != nil { - t.Fatal(err) - } - c := ts.Client() - c.Transport.(*Transport).Proxy = ProxyURL(pu) - if _, err := c.Head(ts.URL); err != nil { - t.Error(err) - } - var got string - select { - case got = <-ch: - case <-time.After(5 * time.Second): - t.Fatal("timeout connecting to socks5 proxy") + // Implement proxying. + targetHost := net.JoinHostPort(ip.String(), strconv.Itoa(int(port))) + targetConn, err := net.Dial("tcp", targetHost) + if err != nil { + t.Errorf("net.Dial failed") + return + } + go io.Copy(targetConn, s) + io.Copy(s, targetConn) // Wait for the client to close the socket. + targetConn.Close() } - tsu, err := url.Parse(ts.URL) + + pu, err := url.Parse("socks5://" + l.Addr().String()) if err != nil { t.Fatal(err) } - want := "proxy for " + tsu.Host - if got != want { - t.Errorf("got %q, want %q", got, want) + + sentinelHeader := "X-Sentinel" + sentinelValue := "12345" + h := HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set(sentinelHeader, sentinelValue) + }) + for _, useTLS := range []bool{false, true} { + t.Run(fmt.Sprintf("useTLS=%v", useTLS), func(t *testing.T) { + var ts *httptest.Server + if useTLS { + ts = httptest.NewTLSServer(h) + } else { + ts = httptest.NewServer(h) + } + go proxy(t) + c := ts.Client() + c.Transport.(*Transport).Proxy = ProxyURL(pu) + r, err := c.Head(ts.URL) + if err != nil { + t.Fatal(err) + } + if r.Header.Get(sentinelHeader) != sentinelValue { + t.Errorf("Failed to retrieve sentinel value") + } + var got string + select { + case got = <-ch: + case <-time.After(5 * time.Second): + t.Fatal("timeout connecting to socks5 proxy") + } + ts.Close() + tsu, err := url.Parse(ts.URL) + if err != nil { + t.Fatal(err) + } + want := "proxy for " + tsu.Host + if got != want { + t.Errorf("got %q, want %q", got, want) + } + }) } } func TestTransportProxy(t *testing.T) { defer afterTest(t) - ch := make(chan string, 1) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { - ch <- "real server" - })) - defer ts.Close() - proxy := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { - ch <- "proxy for " + r.URL.String() - })) - defer proxy.Close() + testCases := []struct{ httpsSite, httpsProxy bool }{ + {false, false}, + {false, true}, + {true, false}, + {true, true}, + } + for _, testCase := range testCases { + httpsSite := testCase.httpsSite + httpsProxy := testCase.httpsProxy + t.Run(fmt.Sprintf("httpsSite=%v, httpsProxy=%v", httpsSite, httpsProxy), func(t *testing.T) { + siteCh := make(chan *Request, 1) + h1 := HandlerFunc(func(w ResponseWriter, r *Request) { + siteCh <- r + }) + proxyCh := make(chan *Request, 1) + h2 := HandlerFunc(func(w ResponseWriter, r *Request) { + proxyCh <- r + // Implement an entire CONNECT proxy + if r.Method == "CONNECT" { + hijacker, ok := w.(Hijacker) + if !ok { + t.Errorf("hijack not allowed") + return + } + clientConn, _, err := hijacker.Hijack() + if err != nil { + t.Errorf("hijacking failed") + return + } + res := &Response{ + StatusCode: StatusOK, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(Header), + } + + targetConn, err := net.Dial("tcp", r.URL.Host) + if err != nil { + t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err) + return + } + + if err := res.Write(clientConn); err != nil { + t.Errorf("Writing 200 OK failed: %v", err) + return + } + + go io.Copy(targetConn, clientConn) + go func() { + io.Copy(clientConn, targetConn) + targetConn.Close() + }() + } + }) + var ts *httptest.Server + if httpsSite { + ts = httptest.NewTLSServer(h1) + } else { + ts = httptest.NewServer(h1) + } + var proxy *httptest.Server + if httpsProxy { + proxy = httptest.NewTLSServer(h2) + } else { + proxy = httptest.NewServer(h2) + } - pu, err := url.Parse(proxy.URL) - if err != nil { - t.Fatal(err) - } - c := ts.Client() - c.Transport.(*Transport).Proxy = ProxyURL(pu) - if _, err := c.Head(ts.URL); err != nil { - t.Error(err) - } - var got string - select { - case got = <-ch: - case <-time.After(5 * time.Second): - t.Fatal("timeout connecting to http proxy") - } - want := "proxy for " + ts.URL + "/" - if got != want { - t.Errorf("got %q, want %q", got, want) + pu, err := url.Parse(proxy.URL) + if err != nil { + t.Fatal(err) + } + + // If neither server is HTTPS or both are, then c may be derived from either. + // If only one server is HTTPS, c must be derived from that server in order + // to ensure that it is configured to use the fake root CA from testcert.go. + c := proxy.Client() + if httpsSite { + c = ts.Client() + } + + c.Transport.(*Transport).Proxy = ProxyURL(pu) + if _, err := c.Head(ts.URL); err != nil { + t.Error(err) + } + var got *Request + select { + case got = <-proxyCh: + case <-time.After(5 * time.Second): + t.Fatal("timeout connecting to http proxy") + } + c.Transport.(*Transport).CloseIdleConnections() + ts.Close() + proxy.Close() + if httpsSite { + // First message should be a CONNECT, asking for a socket to the real server, + if got.Method != "CONNECT" { + t.Errorf("Wrong method for secure proxying: %q", got.Method) + } + gotHost := got.URL.Host + pu, err := url.Parse(ts.URL) + if err != nil { + t.Fatal("Invalid site URL") + } + if wantHost := pu.Host; gotHost != wantHost { + t.Errorf("Got CONNECT host %q, want %q", gotHost, wantHost) + } + + // The next message on the channel should be from the site's server. + next := <-siteCh + if next.Method != "HEAD" { + t.Errorf("Wrong method at destination: %s", next.Method) + } + if nextURL := next.URL.String(); nextURL != "/" { + t.Errorf("Wrong URL at destination: %s", nextURL) + } + } else { + if got.Method != "HEAD" { + t.Errorf("Wrong method for destination: %q", got.Method) + } + gotURL := got.URL.String() + wantURL := ts.URL + "/" + if gotURL != wantURL { + t.Errorf("Got URL %q, want %q", gotURL, wantURL) + } + } + }) } } @@ -4118,3 +4268,100 @@ var rgz = []byte{ 0x00, 0x00, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00, 0x00, 0x00, } + +// Ensure that a missing status doesn't make the server panic +// See Issue https://golang.org/issues/21701 +func TestMissingStatusNoPanic(t *testing.T) { + t.Parallel() + + const want = "unknown status code" + + ln := newLocalListener(t) + addr := ln.Addr().String() + shutdown := make(chan bool, 1) + done := make(chan bool) + fullAddrURL := fmt.Sprintf("http://%s", addr) + raw := "HTTP/1.1 400\r\n" + + "Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" + + "Content-Type: text/html; charset=utf-8\r\n" + + "Content-Length: 10\r\n" + + "Last-Modified: Wed, 30 Aug 2017 19:02:02 GMT\r\n" + + "Vary: Accept-Encoding\r\n\r\n" + + "Aloha Olaa" + + go func() { + defer func() { + ln.Close() + close(done) + }() + + conn, _ := ln.Accept() + if conn != nil { + io.WriteString(conn, raw) + ioutil.ReadAll(conn) + conn.Close() + } + }() + + proxyURL, err := url.Parse(fullAddrURL) + if err != nil { + t.Fatalf("proxyURL: %v", err) + } + + tr := &Transport{Proxy: ProxyURL(proxyURL)} + + req, _ := NewRequest("GET", "https://golang.org/", nil) + res, err, panicked := doFetchCheckPanic(tr, req) + if panicked { + t.Error("panicked, expecting an error") + } + if res != nil && res.Body != nil { + io.Copy(ioutil.Discard, res.Body) + res.Body.Close() + } + + if err == nil || !strings.Contains(err.Error(), want) { + t.Errorf("got=%v want=%q", err, want) + } + + close(shutdown) + <-done +} + +func doFetchCheckPanic(tr *Transport, req *Request) (res *Response, err error, panicked bool) { + defer func() { + if r := recover(); r != nil { + panicked = true + } + }() + res, err = tr.RoundTrip(req) + return +} + +// Issue 22330: do not allow the response body to be read when the status code +// forbids a response body. +func TestNoBodyOnChunked304Response(t *testing.T) { + defer afterTest(t) + cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + conn, buf, _ := w.(Hijacker).Hijack() + buf.Write([]byte("HTTP/1.1 304 NOT MODIFIED\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n")) + buf.Flush() + conn.Close() + })) + defer cst.close() + + // Our test server above is sending back bogus data after the + // response (the "0\r\n\r\n" part), which causes the Transport + // code to log spam. Disable keep-alives so we never even try + // to reuse the connection. + cst.tr.DisableKeepAlives = true + + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + t.Fatal(err) + } + + if res.Body != NoBody { + t.Errorf("Unexpected body on 304 response") + } +} diff --git a/libgo/go/net/internal/socktest/sys_windows.go b/libgo/go/net/internal/socktest/sys_windows.go index 2e3d2bc7fce..8c1c862f33c 100644 --- a/libgo/go/net/internal/socktest/sys_windows.go +++ b/libgo/go/net/internal/socktest/sys_windows.go @@ -4,7 +4,10 @@ package socktest -import "syscall" +import ( + "internal/syscall/windows" + "syscall" +) // Socket wraps syscall.Socket. func (sw *Switch) Socket(family, sotype, proto int) (s syscall.Handle, err error) { @@ -38,6 +41,38 @@ func (sw *Switch) Socket(family, sotype, proto int) (s syscall.Handle, err error return s, nil } +// WSASocket wraps syscall.WSASocket. +func (sw *Switch) WSASocket(family, sotype, proto int32, protinfo *syscall.WSAProtocolInfo, group uint32, flags uint32) (s syscall.Handle, err error) { + sw.once.Do(sw.init) + + so := &Status{Cookie: cookie(int(family), int(sotype), int(proto))} + sw.fmu.RLock() + f, _ := sw.fltab[FilterSocket] + sw.fmu.RUnlock() + + af, err := f.apply(so) + if err != nil { + return syscall.InvalidHandle, err + } + s, so.Err = windows.WSASocket(family, sotype, proto, protinfo, group, flags) + if err = af.apply(so); err != nil { + if so.Err == nil { + syscall.Closesocket(s) + } + return syscall.InvalidHandle, err + } + + sw.smu.Lock() + defer sw.smu.Unlock() + if so.Err != nil { + sw.stats.getLocked(so.Cookie).OpenFailed++ + return syscall.InvalidHandle, so.Err + } + nso := sw.addLocked(s, int(family), int(sotype), int(proto)) + sw.stats.getLocked(nso.Cookie).Opened++ + return s, nil +} + // Closesocket wraps syscall.Closesocket. func (sw *Switch) Closesocket(s syscall.Handle) (err error) { so := sw.sockso(s) diff --git a/libgo/go/net/iprawsock.go b/libgo/go/net/iprawsock.go index c4b54f00c4e..72cbc394337 100644 --- a/libgo/go/net/iprawsock.go +++ b/libgo/go/net/iprawsock.go @@ -21,7 +21,7 @@ import ( // change the behavior of these methods; use Read or ReadMsgIP // instead. -// BUG(mikio): On NaCl, Plan 9 and Windows, the ReadMsgIP and +// BUG(mikio): On NaCl and Plan 9, the ReadMsgIP and // WriteMsgIP methods of IPConn are not implemented. // BUG(mikio): On Windows, the File method of IPConn is not diff --git a/libgo/go/net/listen_test.go b/libgo/go/net/listen_test.go index 21ad4462f68..96624f98ce5 100644 --- a/libgo/go/net/listen_test.go +++ b/libgo/go/net/listen_test.go @@ -13,6 +13,7 @@ import ( "runtime" "syscall" "testing" + "time" ) func (ln *TCPListener) port() string { @@ -696,3 +697,35 @@ func multicastRIBContains(ip IP) (bool, error) { } return false, nil } + +// Issue 21856. +func TestClosingListener(t *testing.T) { + ln, err := newLocalListener("tcp") + if err != nil { + t.Fatal(err) + } + addr := ln.Addr() + + go func() { + for { + c, err := ln.Accept() + if err != nil { + return + } + c.Close() + } + }() + + // Let the goroutine start. We don't sleep long: if the + // goroutine doesn't start, the test will pass without really + // testing anything, which is OK. + time.Sleep(time.Millisecond) + + ln.Close() + + ln2, err := Listen("tcp", addr.String()) + if err != nil { + t.Fatal(err) + } + ln2.Close() +} diff --git a/libgo/go/net/lookup_plan9.go b/libgo/go/net/lookup_plan9.go index f81e220fc8c..1037b81a3be 100644 --- a/libgo/go/net/lookup_plan9.go +++ b/libgo/go/net/lookup_plan9.go @@ -198,7 +198,7 @@ func (*Resolver) lookupPort(ctx context.Context, network, service string) (port func (*Resolver) lookupCNAME(ctx context.Context, name string) (cname string, err error) { lines, err := queryDNS(ctx, name, "cname") if err != nil { - if stringsHasSuffix(err.Error(), "dns failure") { + if stringsHasSuffix(err.Error(), "dns failure") || stringsHasSuffix(err.Error(), "resource does not exist; negrcode 0") { cname = name + "." err = nil } diff --git a/libgo/go/net/lookup_test.go b/libgo/go/net/lookup_test.go index 68a7abe95df..e3bf114a8e2 100644 --- a/libgo/go/net/lookup_test.go +++ b/libgo/go/net/lookup_test.go @@ -9,7 +9,9 @@ import ( "context" "fmt" "internal/testenv" + "reflect" "runtime" + "sort" "strings" "testing" "time" @@ -303,6 +305,28 @@ func TestLookupGoogleHost(t *testing.T) { } } +func TestLookupLongTXT(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("skipping on plan9; see https://golang.org/issue/22857") + } + if testenv.Builder() == "" { + testenv.MustHaveExternalNetwork(t) + } + + txts, err := LookupTXT("golang.rsc.io") + if err != nil { + t.Fatal(err) + } + sort.Strings(txts) + want := []string{ + strings.Repeat("abcdefghijklmnopqrstuvwxyABCDEFGHJIKLMNOPQRSTUVWXY", 10), + "gophers rule", + } + if !reflect.DeepEqual(txts, want) { + t.Fatalf("LookupTXT golang.rsc.io incorrect\nhave %q\nwant %q", txts, want) + } +} + var lookupGoogleIPTests = []struct { name string }{ diff --git a/libgo/go/net/lookup_windows.go b/libgo/go/net/lookup_windows.go index 0036d89d150..ac1f9b431ac 100644 --- a/libgo/go/net/lookup_windows.go +++ b/libgo/go/net/lookup_windows.go @@ -279,10 +279,11 @@ func (*Resolver) lookupTXT(ctx context.Context, name string) ([]string, error) { txts := make([]string, 0, 10) for _, p := range validRecs(r, syscall.DNS_TYPE_TEXT, name) { d := (*syscall.DNSTXTData)(unsafe.Pointer(&p.Data[0])) + s := "" for _, v := range (*[1 << 10]*uint16)(unsafe.Pointer(&(d.StringArray[0])))[:d.StringCount] { - s := syscall.UTF16ToString((*[1 << 20]uint16)(unsafe.Pointer(v))[:]) - txts = append(txts, s) + s += syscall.UTF16ToString((*[1 << 20]uint16)(unsafe.Pointer(v))[:]) } + txts = append(txts, s) } return txts, nil } diff --git a/libgo/go/net/mail/message.go b/libgo/go/net/mail/message.go index 45a995ec720..4f3184f3e8a 100644 --- a/libgo/go/net/mail/message.go +++ b/libgo/go/net/mail/message.go @@ -10,10 +10,10 @@ extended by RFC 6532. Notable divergences: * Obsolete address formats are not parsed, including addresses with embedded route information. - * Group addresses are not parsed. * The full range of spacing (the CFWS syntax element) is not supported, such as breaking addresses across lines. * No unicode normalization is performed. + * The special characters ()[]:;@\, are allowed to appear unquoted in names. */ package mail @@ -190,7 +190,7 @@ func (a *Address) String() string { // Add quotes if needed quoteLocal := false for i, r := range local { - if isAtext(r, false) { + if isAtext(r, false, false) { continue } if r == '.' { @@ -247,13 +247,15 @@ func (p *addrParser) parseAddressList() ([]*Address, error) { var list []*Address for { p.skipSpace() - addr, err := p.parseAddress() + addrs, err := p.parseAddress(true) if err != nil { return nil, err } - list = append(list, addr) + list = append(list, addrs...) - p.skipSpace() + if !p.skipCFWS() { + return nil, errors.New("mail: misformatted parenthetical comment") + } if p.empty() { break } @@ -265,36 +267,55 @@ func (p *addrParser) parseAddressList() ([]*Address, error) { } func (p *addrParser) parseSingleAddress() (*Address, error) { - addr, err := p.parseAddress() + addrs, err := p.parseAddress(true) if err != nil { return nil, err } - p.skipSpace() + if !p.skipCFWS() { + return nil, errors.New("mail: misformatted parenthetical comment") + } if !p.empty() { return nil, fmt.Errorf("mail: expected single address, got %q", p.s) } - return addr, nil + if len(addrs) == 0 { + return nil, errors.New("mail: empty group") + } + if len(addrs) > 1 { + return nil, errors.New("mail: group with multiple addresses") + } + return addrs[0], nil } // parseAddress parses a single RFC 5322 address at the start of p. -func (p *addrParser) parseAddress() (addr *Address, err error) { +func (p *addrParser) parseAddress(handleGroup bool) ([]*Address, error) { debug.Printf("parseAddress: %q", p.s) p.skipSpace() if p.empty() { return nil, errors.New("mail: no address") } - // address = name-addr / addr-spec - // TODO(dsymonds): Support parsing group address. + // address = mailbox / group + // mailbox = name-addr / addr-spec + // group = display-name ":" [group-list] ";" [CFWS] // addr-spec has a more restricted grammar than name-addr, // so try parsing it first, and fallback to name-addr. // TODO(dsymonds): Is this really correct? spec, err := p.consumeAddrSpec() if err == nil { - return &Address{ + var displayName string + p.skipSpace() + if !p.empty() && p.peek() == '(' { + displayName, err = p.consumeDisplayNameComment() + if err != nil { + return nil, err + } + } + + return []*Address{{ + Name: displayName, Address: spec, - }, err + }}, err } debug.Printf("parseAddress: not an addr-spec: %v", err) debug.Printf("parseAddress: state is now %q", p.s) @@ -309,8 +330,13 @@ func (p *addrParser) parseAddress() (addr *Address, err error) { } debug.Printf("parseAddress: displayName=%q", displayName) - // angle-addr = "<" addr-spec ">" p.skipSpace() + if handleGroup { + if p.consume(':') { + return p.consumeGroupList() + } + } + // angle-addr = "<" addr-spec ">" if !p.consume('<') { return nil, errors.New("mail: no angle-addr") } @@ -323,10 +349,42 @@ func (p *addrParser) parseAddress() (addr *Address, err error) { } debug.Printf("parseAddress: spec=%q", spec) - return &Address{ + return []*Address{{ Name: displayName, Address: spec, - }, nil + }}, nil +} + +func (p *addrParser) consumeGroupList() ([]*Address, error) { + var group []*Address + // handle empty group. + p.skipSpace() + if p.consume(';') { + p.skipCFWS() + return group, nil + } + + for { + p.skipSpace() + // embedded groups not allowed. + addrs, err := p.parseAddress(false) + if err != nil { + return nil, err + } + group = append(group, addrs...) + + if !p.skipCFWS() { + return nil, errors.New("mail: misformatted parenthetical comment") + } + if p.consume(';') { + p.skipCFWS() + break + } + if !p.consume(',') { + return nil, errors.New("mail: expected comma") + } + } + return group, nil } // consumeAddrSpec parses a single RFC 5322 addr-spec at the start of p. @@ -482,20 +540,20 @@ Loop: // consumeAtom parses an RFC 5322 atom at the start of p. // If dot is true, consumeAtom parses an RFC 5322 dot-atom instead. -// If permissive is true, consumeAtom will not fail on -// leading/trailing/double dots in the atom (see golang.org/issue/4938). +// If permissive is true, consumeAtom will not fail on: +// - leading/trailing/double dots in the atom (see golang.org/issue/4938) +// - special characters (RFC 5322 3.2.3) except '<', '>', ':' and '"' (see golang.org/issue/21018) func (p *addrParser) consumeAtom(dot bool, permissive bool) (atom string, err error) { i := 0 Loop: for { r, size := utf8.DecodeRuneInString(p.s[i:]) - switch { case size == 1 && r == utf8.RuneError: return "", fmt.Errorf("mail: invalid utf-8 in address: %q", p.s) - case size == 0 || !isAtext(r, dot): + case size == 0 || !isAtext(r, dot, permissive): break Loop default: @@ -522,6 +580,30 @@ Loop: return atom, nil } +func (p *addrParser) consumeDisplayNameComment() (string, error) { + if !p.consume('(') { + return "", errors.New("mail: comment does not start with (") + } + comment, ok := p.consumeComment() + if !ok { + return "", errors.New("mail: misformatted parenthetical comment") + } + + // TODO(stapelberg): parse quoted-string within comment + words := strings.FieldsFunc(comment, func(r rune) bool { return r == ' ' || r == '\t' }) + for idx, word := range words { + decoded, isEncoded, err := p.decodeRFC2047Word(word) + if err != nil { + return "", err + } + if isEncoded { + words[idx] = decoded + } + } + + return strings.Join(words, " "), nil +} + func (p *addrParser) consume(c byte) bool { if p.empty() || p.peek() != c { return false @@ -547,6 +629,51 @@ func (p *addrParser) len() int { return len(p.s) } +// skipCFWS skips CFWS as defined in RFC5322. +func (p *addrParser) skipCFWS() bool { + p.skipSpace() + + for { + if !p.consume('(') { + break + } + + if _, ok := p.consumeComment(); !ok { + return false + } + + p.skipSpace() + } + + return true +} + +func (p *addrParser) consumeComment() (string, bool) { + // '(' already consumed. + depth := 1 + + var comment string + for { + if p.empty() || depth == 0 { + break + } + + if p.peek() == '\\' && p.len() > 1 { + p.s = p.s[1:] + } else if p.peek() == '(' { + depth++ + } else if p.peek() == ')' { + depth-- + } + if depth > 0 { + comment += p.s[:1] + } + p.s = p.s[1:] + } + + return comment, depth == 0 +} + func (p *addrParser) decodeRFC2047Word(s string) (word string, isEncoded bool, err error) { if p.dec != nil { word, err = p.dec.Decode(s) @@ -580,12 +707,18 @@ func (e charsetError) Error() string { // isAtext reports whether r is an RFC 5322 atext character. // If dot is true, period is included. -func isAtext(r rune, dot bool) bool { +// If permissive is true, RFC 5322 3.2.3 specials is included, +// except '<', '>', ':' and '"'. +func isAtext(r rune, dot, permissive bool) bool { switch r { case '.': return dot - case '(', ')', '<', '>', '[', ']', ':', ';', '@', '\\', ',', '"': // RFC 5322 3.2.3. specials + // RFC 5322 3.2.3. specials + case '(', ')', '[', ']', ';', '@', '\\', ',': + return permissive + + case '<', '>', '"', ':': return false } return isVchar(r) diff --git a/libgo/go/net/mail/message_test.go b/libgo/go/net/mail/message_test.go index 2106a0b97d6..b19da52c423 100644 --- a/libgo/go/net/mail/message_test.go +++ b/libgo/go/net/mail/message_test.go @@ -129,14 +129,21 @@ func TestAddressParsingError(t *testing.T) { text string wantErrText string }{ - 0: {"=?iso-8859-2?Q?Bogl=E1rka_Tak=E1cs?= <unknown@gmail.com>", "charset not supported"}, - 1: {"a@gmail.com b@gmail.com", "expected single address"}, - 2: {string([]byte{0xed, 0xa0, 0x80}) + " <micro@example.net>", "invalid utf-8 in address"}, - 3: {"\"" + string([]byte{0xed, 0xa0, 0x80}) + "\" <half-surrogate@example.com>", "invalid utf-8 in quoted-string"}, - 4: {"\"\\" + string([]byte{0x80}) + "\" <escaped-invalid-unicode@example.net>", "invalid utf-8 in quoted-string"}, - 5: {"\"\x00\" <null@example.net>", "bad character in quoted-string"}, - 6: {"\"\\\x00\" <escaped-null@example.net>", "bad character in quoted-string"}, - 7: {"John Doe", "no angle-addr"}, + 0: {"=?iso-8859-2?Q?Bogl=E1rka_Tak=E1cs?= <unknown@gmail.com>", "charset not supported"}, + 1: {"a@gmail.com b@gmail.com", "expected single address"}, + 2: {string([]byte{0xed, 0xa0, 0x80}) + " <micro@example.net>", "invalid utf-8 in address"}, + 3: {"\"" + string([]byte{0xed, 0xa0, 0x80}) + "\" <half-surrogate@example.com>", "invalid utf-8 in quoted-string"}, + 4: {"\"\\" + string([]byte{0x80}) + "\" <escaped-invalid-unicode@example.net>", "invalid utf-8 in quoted-string"}, + 5: {"\"\x00\" <null@example.net>", "bad character in quoted-string"}, + 6: {"\"\\\x00\" <escaped-null@example.net>", "bad character in quoted-string"}, + 7: {"John Doe", "no angle-addr"}, + 8: {`<jdoe#machine.example>`, "missing @ in addr-spec"}, + 9: {`John <middle> Doe <jdoe@machine.example>`, "missing @ in addr-spec"}, + 10: {"cfws@example.com (", "misformatted parenthetical comment"}, + 11: {"empty group: ;", "empty group"}, + 12: {"root group: embed group: null@example.com;", "no angle-addr"}, + 13: {"group not closed: null@example.com", "expected comma"}, + 14: {"group: first@example.com, second@example.com;", "group with multiple addresses"}, } for i, tc := range mustErrTestCases { @@ -176,6 +183,34 @@ func TestAddressParsing(t *testing.T) { }}, }, { + `"John (middle) Doe" <jdoe@machine.example>`, + []*Address{{ + Name: "John (middle) Doe", + Address: "jdoe@machine.example", + }}, + }, + { + `John (middle) Doe <jdoe@machine.example>`, + []*Address{{ + Name: "John (middle) Doe", + Address: "jdoe@machine.example", + }}, + }, + { + `John !@M@! Doe <jdoe@machine.example>`, + []*Address{{ + Name: "John !@M@! Doe", + Address: "jdoe@machine.example", + }}, + }, + { + `"John <middle> Doe" <jdoe@machine.example>`, + []*Address{{ + Name: "John <middle> Doe", + Address: "jdoe@machine.example", + }}, + }, + { `Mary Smith <mary@x.test>, jdoe@example.org, Who? <one@y.test>`, []*Address{ { @@ -203,9 +238,62 @@ func TestAddressParsing(t *testing.T) { }, }, }, + // RFC 5322, Appendix A.6.1 + { + `Joe Q. Public <john.q.public@example.com>`, + []*Address{{ + Name: "Joe Q. Public", + Address: "john.q.public@example.com", + }}, + }, // RFC 5322, Appendix A.1.3 - // TODO(dsymonds): Group addresses. - + { + `group1: groupaddr1@example.com;`, + []*Address{ + { + Name: "", + Address: "groupaddr1@example.com", + }, + }, + }, + { + `empty group: ;`, + []*Address(nil), + }, + { + `A Group:Ed Jones <c@a.test>,joe@where.test,John <jdoe@one.test>;`, + []*Address{ + { + Name: "Ed Jones", + Address: "c@a.test", + }, + { + Name: "", + Address: "joe@where.test", + }, + { + Name: "John", + Address: "jdoe@one.test", + }, + }, + }, + { + `Group1: <addr1@example.com>;, Group 2: addr2@example.com;, John <addr3@example.com>`, + []*Address{ + { + Name: "", + Address: "addr1@example.com", + }, + { + Name: "", + Address: "addr2@example.com", + }, + { + Name: "John", + Address: "addr3@example.com", + }, + }, + }, // RFC 2047 "Q"-encoded ISO-8859-1 address. { `=?iso-8859-1?q?J=F6rg_Doe?= <joerg@example.com>`, @@ -336,6 +424,89 @@ func TestAddressParsing(t *testing.T) { }, }, }, + // CFWS + { + `<cfws@example.com> (CFWS (cfws)) (another comment)`, + []*Address{ + { + Name: "", + Address: "cfws@example.com", + }, + }, + }, + { + `<cfws@example.com> () (another comment), <cfws2@example.com> (another)`, + []*Address{ + { + Name: "", + Address: "cfws@example.com", + }, + { + Name: "", + Address: "cfws2@example.com", + }, + }, + }, + // Comment as display name + { + `john@example.com (John Doe)`, + []*Address{ + { + Name: "John Doe", + Address: "john@example.com", + }, + }, + }, + // Comment and display name + { + `John Doe <john@example.com> (Joey)`, + []*Address{ + { + Name: "John Doe", + Address: "john@example.com", + }, + }, + }, + // Comment as display name, no space + { + `john@example.com(John Doe)`, + []*Address{ + { + Name: "John Doe", + Address: "john@example.com", + }, + }, + }, + // Comment as display name, Q-encoded + { + `asjo@example.com (Adam =?utf-8?Q?Sj=C3=B8gren?=)`, + []*Address{ + { + Name: "Adam Sjøgren", + Address: "asjo@example.com", + }, + }, + }, + // Comment as display name, Q-encoded and tab-separated + { + `asjo@example.com (Adam =?utf-8?Q?Sj=C3=B8gren?=)`, + []*Address{ + { + Name: "Adam Sjøgren", + Address: "asjo@example.com", + }, + }, + }, + // Nested comment as display name, Q-encoded + { + `asjo@example.com (Adam =?utf-8?Q?Sj=C3=B8gren?= (Debian))`, + []*Address{ + { + Name: "Adam Sjøgren (Debian)", + Address: "asjo@example.com", + }, + }, + }, } for _, test := range tests { if len(test.exp) == 1 { diff --git a/libgo/go/net/main_windows_test.go b/libgo/go/net/main_windows_test.go index f38a3a0d668..07f21b72eb1 100644 --- a/libgo/go/net/main_windows_test.go +++ b/libgo/go/net/main_windows_test.go @@ -9,6 +9,7 @@ import "internal/poll" var ( // Placeholders for saving original socket system calls. origSocket = socketFunc + origWSASocket = wsaSocketFunc origClosesocket = poll.CloseFunc origConnect = connectFunc origConnectEx = poll.ConnectExFunc @@ -18,6 +19,7 @@ var ( func installTestHooks() { socketFunc = sw.Socket + wsaSocketFunc = sw.WSASocket poll.CloseFunc = sw.Closesocket connectFunc = sw.Connect poll.ConnectExFunc = sw.ConnectEx @@ -27,6 +29,7 @@ func installTestHooks() { func uninstallTestHooks() { socketFunc = origSocket + wsaSocketFunc = origWSASocket poll.CloseFunc = origClosesocket connectFunc = origConnect poll.ConnectExFunc = origConnectEx diff --git a/libgo/go/net/parse.go b/libgo/go/net/parse.go index b270159cd88..a2d9245348c 100644 --- a/libgo/go/net/parse.go +++ b/libgo/go/net/parse.go @@ -69,7 +69,7 @@ func open(name string) (*file, error) { if err != nil { return nil, err } - return &file{fd, make([]byte, 0, os.Getpagesize()), false}, nil + return &file{fd, make([]byte, 0, 64*1024), false}, nil } func stat(name string) (mtime time.Time, size int64, err error) { diff --git a/libgo/go/net/pipe.go b/libgo/go/net/pipe.go index 37e552f54e5..9177fc40364 100644 --- a/libgo/go/net/pipe.go +++ b/libgo/go/net/pipe.go @@ -5,63 +5,239 @@ package net import ( - "errors" "io" + "sync" "time" ) +// pipeDeadline is an abstraction for handling timeouts. +type pipeDeadline struct { + mu sync.Mutex // Guards timer and cancel + timer *time.Timer + cancel chan struct{} // Must be non-nil +} + +func makePipeDeadline() pipeDeadline { + return pipeDeadline{cancel: make(chan struct{})} +} + +// set sets the point in time when the deadline will time out. +// A timeout event is signaled by closing the channel returned by waiter. +// Once a timeout has occurred, the deadline can be refreshed by specifying a +// t value in the future. +// +// A zero value for t prevents timeout. +func (d *pipeDeadline) set(t time.Time) { + d.mu.Lock() + defer d.mu.Unlock() + + if d.timer != nil && !d.timer.Stop() { + <-d.cancel // Wait for the timer callback to finish and close cancel + } + d.timer = nil + + // Time is zero, then there is no deadline. + closed := isClosedChan(d.cancel) + if t.IsZero() { + if closed { + d.cancel = make(chan struct{}) + } + return + } + + // Time in the future, setup a timer to cancel in the future. + if dur := time.Until(t); dur > 0 { + if closed { + d.cancel = make(chan struct{}) + } + d.timer = time.AfterFunc(dur, func() { + close(d.cancel) + }) + return + } + + // Time in the past, so close immediately. + if !closed { + close(d.cancel) + } +} + +// wait returns a channel that is closed when the deadline is exceeded. +func (d *pipeDeadline) wait() chan struct{} { + d.mu.Lock() + defer d.mu.Unlock() + return d.cancel +} + +func isClosedChan(c <-chan struct{}) bool { + select { + case <-c: + return true + default: + return false + } +} + +type timeoutError struct{} + +func (timeoutError) Error() string { return "deadline exceeded" } +func (timeoutError) Timeout() bool { return true } +func (timeoutError) Temporary() bool { return true } + +type pipeAddr struct{} + +func (pipeAddr) Network() string { return "pipe" } +func (pipeAddr) String() string { return "pipe" } + +type pipe struct { + wrMu sync.Mutex // Serialize Write operations + + // Used by local Read to interact with remote Write. + // Successful receive on rdRx is always followed by send on rdTx. + rdRx <-chan []byte + rdTx chan<- int + + // Used by local Write to interact with remote Read. + // Successful send on wrTx is always followed by receive on wrRx. + wrTx chan<- []byte + wrRx <-chan int + + once sync.Once // Protects closing localDone + localDone chan struct{} + remoteDone <-chan struct{} + + readDeadline pipeDeadline + writeDeadline pipeDeadline +} + // Pipe creates a synchronous, in-memory, full duplex // network connection; both ends implement the Conn interface. // Reads on one end are matched with writes on the other, // copying data directly between the two; there is no internal // buffering. func Pipe() (Conn, Conn) { - r1, w1 := io.Pipe() - r2, w2 := io.Pipe() - - return &pipe{r1, w2}, &pipe{r2, w1} -} + cb1 := make(chan []byte) + cb2 := make(chan []byte) + cn1 := make(chan int) + cn2 := make(chan int) + done1 := make(chan struct{}) + done2 := make(chan struct{}) -type pipe struct { - *io.PipeReader - *io.PipeWriter + p1 := &pipe{ + rdRx: cb1, rdTx: cn1, + wrTx: cb2, wrRx: cn2, + localDone: done1, remoteDone: done2, + readDeadline: makePipeDeadline(), + writeDeadline: makePipeDeadline(), + } + p2 := &pipe{ + rdRx: cb2, rdTx: cn2, + wrTx: cb1, wrRx: cn1, + localDone: done2, remoteDone: done1, + readDeadline: makePipeDeadline(), + writeDeadline: makePipeDeadline(), + } + return p1, p2 } -type pipeAddr int +func (*pipe) LocalAddr() Addr { return pipeAddr{} } +func (*pipe) RemoteAddr() Addr { return pipeAddr{} } -func (pipeAddr) Network() string { - return "pipe" +func (p *pipe) Read(b []byte) (int, error) { + n, err := p.read(b) + if err != nil && err != io.EOF && err != io.ErrClosedPipe { + err = &OpError{Op: "read", Net: "pipe", Err: err} + } + return n, err } -func (pipeAddr) String() string { - return "pipe" -} +func (p *pipe) read(b []byte) (n int, err error) { + switch { + case isClosedChan(p.localDone): + return 0, io.ErrClosedPipe + case isClosedChan(p.remoteDone): + return 0, io.EOF + case isClosedChan(p.readDeadline.wait()): + return 0, timeoutError{} + } -func (p *pipe) Close() error { - err := p.PipeReader.Close() - err1 := p.PipeWriter.Close() - if err == nil { - err = err1 + select { + case bw := <-p.rdRx: + nr := copy(b, bw) + p.rdTx <- nr + return nr, nil + case <-p.localDone: + return 0, io.ErrClosedPipe + case <-p.remoteDone: + return 0, io.EOF + case <-p.readDeadline.wait(): + return 0, timeoutError{} } - return err } -func (p *pipe) LocalAddr() Addr { - return pipeAddr(0) +func (p *pipe) Write(b []byte) (int, error) { + n, err := p.write(b) + if err != nil && err != io.ErrClosedPipe { + err = &OpError{Op: "write", Net: "pipe", Err: err} + } + return n, err } -func (p *pipe) RemoteAddr() Addr { - return pipeAddr(0) +func (p *pipe) write(b []byte) (n int, err error) { + switch { + case isClosedChan(p.localDone): + return 0, io.ErrClosedPipe + case isClosedChan(p.remoteDone): + return 0, io.ErrClosedPipe + case isClosedChan(p.writeDeadline.wait()): + return 0, timeoutError{} + } + + p.wrMu.Lock() // Ensure entirety of b is written together + defer p.wrMu.Unlock() + for once := true; once || len(b) > 0; once = false { + select { + case p.wrTx <- b: + nw := <-p.wrRx + b = b[nw:] + n += nw + case <-p.localDone: + return n, io.ErrClosedPipe + case <-p.remoteDone: + return n, io.ErrClosedPipe + case <-p.writeDeadline.wait(): + return n, timeoutError{} + } + } + return n, nil } func (p *pipe) SetDeadline(t time.Time) error { - return &OpError{Op: "set", Net: "pipe", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} + if isClosedChan(p.localDone) || isClosedChan(p.remoteDone) { + return io.ErrClosedPipe + } + p.readDeadline.set(t) + p.writeDeadline.set(t) + return nil } func (p *pipe) SetReadDeadline(t time.Time) error { - return &OpError{Op: "set", Net: "pipe", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} + if isClosedChan(p.localDone) || isClosedChan(p.remoteDone) { + return io.ErrClosedPipe + } + p.readDeadline.set(t) + return nil } func (p *pipe) SetWriteDeadline(t time.Time) error { - return &OpError{Op: "set", Net: "pipe", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} + if isClosedChan(p.localDone) || isClosedChan(p.remoteDone) { + return io.ErrClosedPipe + } + p.writeDeadline.set(t) + return nil +} + +func (p *pipe) Close() error { + p.once.Do(func() { close(p.localDone) }) + return nil } diff --git a/libgo/go/net/pipe_test.go b/libgo/go/net/pipe_test.go index e3172d882fb..84a71b756bc 100644 --- a/libgo/go/net/pipe_test.go +++ b/libgo/go/net/pipe_test.go @@ -2,54 +2,48 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package net +package net_test import ( - "bytes" "io" + "net" "testing" + "time" + + "golang_org/x/net/nettest" ) -func checkPipeWrite(t *testing.T, w io.Writer, data []byte, c chan int) { - n, err := w.Write(data) - if err != nil { - t.Error(err) - } - if n != len(data) { - t.Errorf("short write: %d != %d", n, len(data)) - } - c <- 0 +func TestPipe(t *testing.T) { + nettest.TestConn(t, func() (c1, c2 net.Conn, stop func(), err error) { + c1, c2 = net.Pipe() + stop = func() { + c1.Close() + c2.Close() + } + return + }) } -func checkPipeRead(t *testing.T, r io.Reader, data []byte, wantErr error) { - buf := make([]byte, len(data)+10) - n, err := r.Read(buf) - if err != wantErr { - t.Error(err) - return +func TestPipeCloseError(t *testing.T) { + c1, c2 := net.Pipe() + c1.Close() + + if _, err := c1.Read(nil); err != io.ErrClosedPipe { + t.Errorf("c1.Read() = %v, want io.ErrClosedPipe", err) } - if n != len(data) || !bytes.Equal(buf[0:n], data) { - t.Errorf("bad read: got %q", buf[0:n]) - return + if _, err := c1.Write(nil); err != io.ErrClosedPipe { + t.Errorf("c1.Write() = %v, want io.ErrClosedPipe", err) + } + if err := c1.SetDeadline(time.Time{}); err != io.ErrClosedPipe { + t.Errorf("c1.SetDeadline() = %v, want io.ErrClosedPipe", err) + } + if _, err := c2.Read(nil); err != io.EOF { + t.Errorf("c2.Read() = %v, want io.EOF", err) + } + if _, err := c2.Write(nil); err != io.ErrClosedPipe { + t.Errorf("c2.Write() = %v, want io.ErrClosedPipe", err) + } + if err := c2.SetDeadline(time.Time{}); err != io.ErrClosedPipe { + t.Errorf("c2.SetDeadline() = %v, want io.ErrClosedPipe", err) } -} - -// TestPipe tests a simple read/write/close sequence. -// Assumes that the underlying io.Pipe implementation -// is solid and we're just testing the net wrapping. -func TestPipe(t *testing.T) { - c := make(chan int) - cli, srv := Pipe() - go checkPipeWrite(t, cli, []byte("hello, world"), c) - checkPipeRead(t, srv, []byte("hello, world"), nil) - <-c - go checkPipeWrite(t, srv, []byte("line 2"), c) - checkPipeRead(t, cli, []byte("line 2"), nil) - <-c - go checkPipeWrite(t, cli, []byte("a third line"), c) - checkPipeRead(t, srv, []byte("a third line"), nil) - <-c - go srv.Close() - checkPipeRead(t, cli, nil, io.EOF) - cli.Close() } diff --git a/libgo/go/net/platform_test.go b/libgo/go/net/platform_test.go index 5841ca35a00..8e7d9151dee 100644 --- a/libgo/go/net/platform_test.go +++ b/libgo/go/net/platform_test.go @@ -43,9 +43,12 @@ func testableNetwork(network string) bool { case "unixpacket": switch runtime.GOOS { case "android", "darwin", "nacl", "plan9", "windows": - fallthrough - case "freebsd": // FreeBSD 8 and below don't support unixpacket return false + case "netbsd": + // It passes on amd64 at least. 386 fails (Issue 22927). arm is unknown. + if runtime.GOARCH == "386" { + return false + } } } switch ss[0] { @@ -149,12 +152,19 @@ func testableListenArgs(network, address, client string) bool { return true } -var condFatalf = func() func(*testing.T, string, ...interface{}) { - // A few APIs, File, Read/WriteMsg{UDP,IP}, are not - // implemented yet on both Plan 9 and Windows. +func condFatalf(t *testing.T, network string, format string, args ...interface{}) { + t.Helper() + // A few APIs like File and Read/WriteMsg{UDP,IP} are not + // fully implemented yet on Plan 9 and Windows. switch runtime.GOOS { - case "plan9", "windows": - return (*testing.T).Logf + case "windows": + if network == "file+net" { + t.Logf(format, args...) + return + } + case "plan9": + t.Logf(format, args...) + return } - return (*testing.T).Fatalf -}() + t.Fatalf(format, args...) +} diff --git a/libgo/go/net/port.go b/libgo/go/net/port.go index 8e1321afa44..32e76286193 100644 --- a/libgo/go/net/port.go +++ b/libgo/go/net/port.go @@ -4,12 +4,12 @@ package net -// parsePort parses service as a decimal interger and returns the +// parsePort parses service as a decimal integer and returns the // corresponding value as port. It is the caller's responsibility to // parse service as a non-decimal integer when needsLookup is true. // // Some system resolvers will return a valid port number when given a number -// over 65536 (see https://github.com/golang/go/issues/11715). Alas, the parser +// over 65536 (see https://golang.org/issues/11715). Alas, the parser // can't bail early on numbers > 65536. Therefore reasonably large/small // numbers are parsed in full and rejected if invalid. func parsePort(service string) (port int, needsLookup bool) { diff --git a/libgo/go/net/protoconn_test.go b/libgo/go/net/protoconn_test.go index 23589d3ca87..05c45d02b9a 100644 --- a/libgo/go/net/protoconn_test.go +++ b/libgo/go/net/protoconn_test.go @@ -54,7 +54,7 @@ func TestTCPListenerSpecificMethods(t *testing.T) { } if f, err := ln.File(); err != nil { - condFatalf(t, "%v", err) + condFatalf(t, "file+net", "%v", err) } else { f.Close() } @@ -139,14 +139,14 @@ func TestUDPConnSpecificMethods(t *testing.T) { t.Fatal(err) } if _, _, err := c.WriteMsgUDP(wb, nil, c.LocalAddr().(*UDPAddr)); err != nil { - condFatalf(t, "%v", err) + condFatalf(t, c.LocalAddr().Network(), "%v", err) } if _, _, _, _, err := c.ReadMsgUDP(rb, nil); err != nil { - condFatalf(t, "%v", err) + condFatalf(t, c.LocalAddr().Network(), "%v", err) } if f, err := c.File(); err != nil { - condFatalf(t, "%v", err) + condFatalf(t, "file+net", "%v", err) } else { f.Close() } @@ -184,7 +184,7 @@ func TestIPConnSpecificMethods(t *testing.T) { c.SetWriteBuffer(2048) if f, err := c.File(); err != nil { - condFatalf(t, "%v", err) + condFatalf(t, "file+net", "%v", err) } else { f.Close() } diff --git a/libgo/go/net/rawconn.go b/libgo/go/net/rawconn.go index d67be644a34..2399c9f31dd 100644 --- a/libgo/go/net/rawconn.go +++ b/libgo/go/net/rawconn.go @@ -60,3 +60,19 @@ func (c *rawConn) Write(f func(uintptr) bool) error { func newRawConn(fd *netFD) (*rawConn, error) { return &rawConn{fd: fd}, nil } + +type rawListener struct { + rawConn +} + +func (l *rawListener) Read(func(uintptr) bool) error { + return syscall.EINVAL +} + +func (l *rawListener) Write(func(uintptr) bool) error { + return syscall.EINVAL +} + +func newRawListener(fd *netFD) (*rawListener, error) { + return &rawListener{rawConn{fd: fd}}, nil +} diff --git a/libgo/go/net/rawconn_unix_test.go b/libgo/go/net/rawconn_unix_test.go index 294249ba5d1..913ad865951 100644 --- a/libgo/go/net/rawconn_unix_test.go +++ b/libgo/go/net/rawconn_unix_test.go @@ -92,3 +92,53 @@ func TestRawConn(t *testing.T) { t.Fatal("should fail") } } + +func TestRawConnListener(t *testing.T) { + ln, err := newLocalListener("tcp") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + cc, err := ln.(*TCPListener).SyscallConn() + if err != nil { + t.Fatal(err) + } + + called := false + op := func(uintptr) bool { + called = true + return true + } + + err = cc.Write(op) + if err == nil { + t.Error("Write should return an error") + } + if called { + t.Error("Write shouldn't call op") + } + + called = false + err = cc.Read(op) + if err == nil { + t.Error("Read should return an error") + } + if called { + t.Error("Read shouldn't call op") + } + + var operr error + fn := func(s uintptr) { + _, operr = syscall.GetsockoptInt(int(s), syscall.SOL_SOCKET, syscall.SO_REUSEADDR) + } + err = cc.Control(fn) + if err != nil || operr != nil { + t.Fatal(err, operr) + } + ln.Close() + err = cc.Control(fn) + if err == nil { + t.Fatal("Control after Close should fail") + } +} diff --git a/libgo/go/net/rawconn_windows_test.go b/libgo/go/net/rawconn_windows_test.go index 5fb6de75393..2ee12c35963 100644 --- a/libgo/go/net/rawconn_windows_test.go +++ b/libgo/go/net/rawconn_windows_test.go @@ -7,6 +7,7 @@ package net import ( "syscall" "testing" + "unsafe" ) func TestRawConn(t *testing.T) { @@ -34,3 +35,55 @@ func TestRawConn(t *testing.T) { t.Fatal("should fail") } } + +func TestRawConnListener(t *testing.T) { + ln, err := newLocalListener("tcp") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + cc, err := ln.(*TCPListener).SyscallConn() + if err != nil { + t.Fatal(err) + } + + called := false + op := func(uintptr) bool { + called = true + return true + } + + err = cc.Write(op) + if err == nil { + t.Error("Write should return an error") + } + if called { + t.Error("Write shouldn't call op") + } + + called = false + err = cc.Read(op) + if err == nil { + t.Error("Read should return an error") + } + if called { + t.Error("Read shouldn't call op") + } + + var operr error + fn := func(s uintptr) { + var v, l int32 + l = int32(unsafe.Sizeof(v)) + operr = syscall.Getsockopt(syscall.Handle(s), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, (*byte)(unsafe.Pointer(&v)), &l) + } + err = cc.Control(fn) + if err != nil || operr != nil { + t.Fatal(err, operr) + } + ln.Close() + err = cc.Control(fn) + if err == nil { + t.Fatal("Control after Close should fail") + } +} diff --git a/libgo/go/net/rpc/server.go b/libgo/go/net/rpc/server.go index 29aae7ee7ff..a0212926037 100644 --- a/libgo/go/net/rpc/server.go +++ b/libgo/go/net/rpc/server.go @@ -372,7 +372,10 @@ func (m *methodType) NumCalls() (n uint) { return n } -func (s *service) call(server *Server, sending *sync.Mutex, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec) { +func (s *service) call(server *Server, sending *sync.Mutex, wg *sync.WaitGroup, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec) { + if wg != nil { + defer wg.Done() + } mtype.Lock() mtype.numCalls++ mtype.Unlock() @@ -456,6 +459,7 @@ func (server *Server) ServeConn(conn io.ReadWriteCloser) { // decode requests and encode responses. func (server *Server) ServeCodec(codec ServerCodec) { sending := new(sync.Mutex) + wg := new(sync.WaitGroup) for { service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec) if err != nil { @@ -472,8 +476,12 @@ func (server *Server) ServeCodec(codec ServerCodec) { } continue } - go service.call(server, sending, mtype, req, argv, replyv, codec) + wg.Add(1) + go service.call(server, sending, wg, mtype, req, argv, replyv, codec) } + // We've seen that there are no more requests. + // Wait for responses to be sent before closing codec. + wg.Wait() codec.Close() } @@ -493,7 +501,7 @@ func (server *Server) ServeRequest(codec ServerCodec) error { } return err } - service.call(server, sending, mtype, req, argv, replyv, codec) + service.call(server, sending, nil, mtype, req, argv, replyv, codec) return nil } diff --git a/libgo/go/net/rpc/server_test.go b/libgo/go/net/rpc/server_test.go index fb97f82a2f7..e5d7fe0c8f5 100644 --- a/libgo/go/net/rpc/server_test.go +++ b/libgo/go/net/rpc/server_test.go @@ -75,6 +75,11 @@ func (t *Arith) Error(args *Args, reply *Reply) error { panic("ERROR") } +func (t *Arith) SleepMilli(args *Args, reply *Reply) error { + time.Sleep(time.Duration(args.A) * time.Millisecond) + return nil +} + type hidden int func (t *hidden) Exported(args Args, reply *Reply) error { @@ -693,6 +698,53 @@ func TestAcceptExitAfterListenerClose(t *testing.T) { newServer.Accept(l) } +func TestShutdown(t *testing.T) { + var l net.Listener + l, _ = listenTCP() + ch := make(chan net.Conn, 1) + go func() { + defer l.Close() + c, err := l.Accept() + if err != nil { + t.Error(err) + } + ch <- c + }() + c, err := net.Dial("tcp", l.Addr().String()) + if err != nil { + t.Fatal(err) + } + c1 := <-ch + if c1 == nil { + t.Fatal(err) + } + + newServer := NewServer() + newServer.Register(new(Arith)) + go newServer.ServeConn(c1) + + args := &Args{7, 8} + reply := new(Reply) + client := NewClient(c) + err = client.Call("Arith.Add", args, reply) + if err != nil { + t.Fatal(err) + } + + // On an unloaded system 10ms is usually enough to fail 100% of the time + // with a broken server. On a loaded system, a broken server might incorrectly + // be reported as passing, but we're OK with that kind of flakiness. + // If the code is correct, this test will never fail, regardless of timeout. + args.A = 10 // 10 ms + done := make(chan *Call, 1) + call := client.Go("Arith.SleepMilli", args, reply, done) + c.(*net.TCPConn).CloseWrite() + <-done + if call.Error != nil { + t.Fatal(err) + } +} + func benchmarkEndToEnd(dial func() (*Client, error), b *testing.B) { once.Do(startServer) client, err := dial() diff --git a/libgo/go/net/smtp/auth.go b/libgo/go/net/smtp/auth.go index 3f1339ebc56..fd1a472f930 100644 --- a/libgo/go/net/smtp/auth.go +++ b/libgo/go/net/smtp/auth.go @@ -44,26 +44,29 @@ type plainAuth struct { } // PlainAuth returns an Auth that implements the PLAIN authentication -// mechanism as defined in RFC 4616. -// The returned Auth uses the given username and password to authenticate -// on TLS connections to host and act as identity. Usually identity will be -// left blank to act as username. +// mechanism as defined in RFC 4616. The returned Auth uses the given +// username and password to authenticate to host and act as identity. +// Usually identity should be the empty string, to act as username. +// +// PlainAuth will only send the credentials if the connection is using TLS +// or is connected to localhost. Otherwise authentication will fail with an +// error, without sending the credentials. func PlainAuth(identity, username, password, host string) Auth { return &plainAuth{identity, username, password, host} } +func isLocalhost(name string) bool { + return name == "localhost" || name == "127.0.0.1" || name == "::1" +} + func (a *plainAuth) Start(server *ServerInfo) (string, []byte, error) { - if !server.TLS { - advertised := false - for _, mechanism := range server.Auth { - if mechanism == "PLAIN" { - advertised = true - break - } - } - if !advertised { - return "", nil, errors.New("unencrypted connection") - } + // Must have TLS, or else localhost server. + // Note: If TLS is not true, then we can't trust ANYTHING in ServerInfo. + // In particular, it doesn't matter if the server advertises PLAIN auth. + // That might just be the attacker saying + // "it's ok, you can trust me with your password." + if !server.TLS && !isLocalhost(server.Name) { + return "", nil, errors.New("unencrypted connection") } if server.Name != a.host { return "", nil, errors.New("wrong host name") diff --git a/libgo/go/net/smtp/smtp.go b/libgo/go/net/smtp/smtp.go index 28472e447b5..cf699e6be82 100644 --- a/libgo/go/net/smtp/smtp.go +++ b/libgo/go/net/smtp/smtp.go @@ -67,6 +67,7 @@ func NewClient(conn net.Conn, host string) (*Client, error) { return nil, err } c := &Client{Text: text, conn: conn, serverName: host, localName: "localhost"} + _, c.tls = conn.(*tls.Conn) return c, nil } @@ -93,6 +94,9 @@ func (c *Client) hello() error { // automatically otherwise. If Hello is called, it must be called before // any of the other methods. func (c *Client) Hello(localName string) error { + if err := validateLine(localName); err != nil { + return err + } if c.didHello { return errors.New("smtp: Hello called after other methods") } @@ -179,6 +183,9 @@ func (c *Client) TLSConnectionState() (state tls.ConnectionState, ok bool) { // does not necessarily indicate an invalid address. Many servers // will not verify addresses for security reasons. func (c *Client) Verify(addr string) error { + if err := validateLine(addr); err != nil { + return err + } if err := c.hello(); err != nil { return err } @@ -237,6 +244,9 @@ func (c *Client) Auth(a Auth) error { // parameter. // This initiates a mail transaction and is followed by one or more Rcpt calls. func (c *Client) Mail(from string) error { + if err := validateLine(from); err != nil { + return err + } if err := c.hello(); err != nil { return err } @@ -254,6 +264,9 @@ func (c *Client) Mail(from string) error { // A call to Rcpt must be preceded by a call to Mail and may be followed by // a Data call or another Rcpt call. func (c *Client) Rcpt(to string) error { + if err := validateLine(to); err != nil { + return err + } _, _, err := c.cmd(25, "RCPT TO:<%s>", to) return err } @@ -304,6 +317,14 @@ var testHookStartTLS func(*tls.Config) // nil, except for tests // functionality. Higher-level packages exist outside of the standard // library. func SendMail(addr string, a Auth, from string, to []string, msg []byte) error { + if err := validateLine(from); err != nil { + return err + } + for _, recp := range to { + if err := validateLine(recp); err != nil { + return err + } + } c, err := Dial(addr) if err != nil { return err @@ -377,6 +398,16 @@ func (c *Client) Reset() error { return err } +// Noop sends the NOOP command to the server. It does nothing but check +// that the connection to the server is okay. +func (c *Client) Noop() error { + if err := c.hello(); err != nil { + return err + } + _, _, err := c.cmd(250, "NOOP") + return err +} + // Quit sends the QUIT command and closes the connection to the server. func (c *Client) Quit() error { if err := c.hello(); err != nil { @@ -388,3 +419,11 @@ func (c *Client) Quit() error { } return c.Text.Close() } + +// validateLine checks to see if a line has CR or LF as per RFC 5321 +func validateLine(line string) error { + if strings.ContainsAny(line, "\n\r") { + return errors.New("smtp: A line must not contain CR or LF") + } + return nil +} diff --git a/libgo/go/net/smtp/smtp_test.go b/libgo/go/net/smtp/smtp_test.go index 9dbe3eb9ecb..d489922597e 100644 --- a/libgo/go/net/smtp/smtp_test.go +++ b/libgo/go/net/smtp/smtp_test.go @@ -62,29 +62,41 @@ testLoop: } func TestAuthPlain(t *testing.T) { - auth := PlainAuth("foo", "bar", "baz", "servername") tests := []struct { - server *ServerInfo - err string + authName string + server *ServerInfo + err string }{ { - server: &ServerInfo{Name: "servername", TLS: true}, + authName: "servername", + server: &ServerInfo{Name: "servername", TLS: true}, }, { - // Okay; explicitly advertised by server. - server: &ServerInfo{Name: "servername", Auth: []string{"PLAIN"}}, + // OK to use PlainAuth on localhost without TLS + authName: "localhost", + server: &ServerInfo{Name: "localhost", TLS: false}, }, { - server: &ServerInfo{Name: "servername", Auth: []string{"CRAM-MD5"}}, - err: "unencrypted connection", + // NOT OK on non-localhost, even if server says PLAIN is OK. + // (We don't know that the server is the real server.) + authName: "servername", + server: &ServerInfo{Name: "servername", Auth: []string{"PLAIN"}}, + err: "unencrypted connection", }, { - server: &ServerInfo{Name: "attacker", TLS: true}, - err: "wrong host name", + authName: "servername", + server: &ServerInfo{Name: "servername", Auth: []string{"CRAM-MD5"}}, + err: "unencrypted connection", + }, + { + authName: "servername", + server: &ServerInfo{Name: "attacker", TLS: true}, + err: "wrong host name", }, } for i, tt := range tests { + auth := PlainAuth("foo", "bar", "baz", tt.authName) _, _, err := auth.Start(tt.server) got := "" if err != nil { @@ -182,6 +194,9 @@ func TestBasic(t *testing.T) { if err := c.Verify("user1@gmail.com"); err == nil { t.Fatalf("First VRFY: expected no verification") } + if err := c.Verify("user2@gmail.com>\r\nDATA\r\nAnother injected message body\r\n.\r\nQUIT\r\n"); err == nil { + t.Fatalf("VRFY should have failed due to a message injection attempt") + } if err := c.Verify("user2@gmail.com"); err != nil { t.Fatalf("Second VRFY: expected verification, got %s", err) } @@ -193,6 +208,12 @@ func TestBasic(t *testing.T) { t.Fatalf("AUTH failed: %s", err) } + if err := c.Rcpt("golang-nuts@googlegroups.com>\r\nDATA\r\nInjected message body\r\n.\r\nQUIT\r\n"); err == nil { + t.Fatalf("RCPT should have failed due to a message injection attempt") + } + if err := c.Mail("user@gmail.com>\r\nDATA\r\nAnother injected message body\r\n.\r\nQUIT\r\n"); err == nil { + t.Fatalf("MAIL should have failed due to a message injection attempt") + } if err := c.Mail("user@gmail.com"); err != nil { t.Fatalf("MAIL failed: %s", err) } @@ -352,6 +373,53 @@ HELO localhost QUIT ` +func TestNewClientWithTLS(t *testing.T) { + cert, err := tls.X509KeyPair(localhostCert, localhostKey) + if err != nil { + t.Fatalf("loadcert: %v", err) + } + + config := tls.Config{Certificates: []tls.Certificate{cert}} + + ln, err := tls.Listen("tcp", "127.0.0.1:0", &config) + if err != nil { + ln, err = tls.Listen("tcp", "[::1]:0", &config) + if err != nil { + t.Fatalf("server: listen: %v", err) + } + } + + go func() { + conn, err := ln.Accept() + if err != nil { + t.Errorf("server: accept: %v", err) + return + } + defer conn.Close() + + _, err = conn.Write([]byte("220 SIGNS\r\n")) + if err != nil { + t.Errorf("server: write: %v", err) + return + } + }() + + config.InsecureSkipVerify = true + conn, err := tls.Dial("tcp", ln.Addr().String(), &config) + if err != nil { + t.Fatalf("client: dial: %v", err) + } + defer conn.Close() + + client, err := NewClient(conn, ln.Addr().String()) + if err != nil { + t.Fatalf("smtp: newclient: %v", err) + } + if !client.tls { + t.Errorf("client.tls Got: %t Expected: %t", client.tls, true) + } +} + func TestHello(t *testing.T) { if len(helloServer) != len(helloClient) { @@ -375,6 +443,10 @@ func TestHello(t *testing.T) { switch i { case 0: + err = c.Hello("hostinjection>\n\rDATA\r\nInjected message body\r\n.\r\nQUIT\r\n") + if err == nil { + t.Errorf("Expected Hello to be rejected due to a message injection attempt") + } err = c.Hello("customhost") case 1: err = c.StartTLS(nil) @@ -406,6 +478,8 @@ func TestHello(t *testing.T) { t.Errorf("Want error, got none") } } + case 9: + err = c.Noop() default: t.Fatalf("Unhandled command") } @@ -438,6 +512,7 @@ var helloServer = []string{ "250 Reset ok\n", "221 Goodbye\n", "250 Sender ok\n", + "250 ok\n", } var baseHelloClient = `EHLO customhost @@ -454,6 +529,7 @@ var helloClient = []string{ "RSET\n", "QUIT\n", "VRFY test@example.com\n", + "NOOP\n", } func TestSendMail(t *testing.T) { @@ -506,6 +582,16 @@ func TestSendMail(t *testing.T) { } }(strings.Split(server, "\r\n")) + err = SendMail(l.Addr().String(), nil, "test@example.com", []string{"other@example.com>\n\rDATA\r\nInjected message body\r\n.\r\nQUIT\r\n"}, []byte(strings.Replace(`From: test@example.com +To: other@example.com +Subject: SendMail test + +SendMail is working for me. +`, "\n", "\r\n", -1))) + if err == nil { + t.Errorf("Expected SendMail to be rejected due to a message injection attempt") + } + err = SendMail(l.Addr().String(), nil, "test@example.com", []string{"other@example.com"}, []byte(strings.Replace(`From: test@example.com To: other@example.com Subject: SendMail test diff --git a/libgo/go/net/sock_bsd.go b/libgo/go/net/sock_bsd.go index 4e0e9e01f2c..dfb09205502 100644 --- a/libgo/go/net/sock_bsd.go +++ b/libgo/go/net/sock_bsd.go @@ -17,8 +17,10 @@ func maxListenerBacklog() int { err error ) switch runtime.GOOS { - case "darwin", "freebsd": + case "darwin": n, err = syscall.SysctlUint32("kern.ipc.somaxconn") + case "freebsd": + n, err = syscall.SysctlUint32("kern.ipc.acceptqueue") case "netbsd": // NOTE: NetBSD has no somaxconn-like kernel state so far case "openbsd": diff --git a/libgo/go/net/sock_windows.go b/libgo/go/net/sock_windows.go index 89a3ca42585..fa11c7af2e7 100644 --- a/libgo/go/net/sock_windows.go +++ b/libgo/go/net/sock_windows.go @@ -5,6 +5,7 @@ package net import ( + "internal/syscall/windows" "os" "syscall" ) @@ -16,9 +17,19 @@ func maxListenerBacklog() int { } func sysSocket(family, sotype, proto int) (syscall.Handle, error) { + s, err := wsaSocketFunc(int32(family), int32(sotype), int32(proto), + nil, 0, windows.WSA_FLAG_OVERLAPPED|windows.WSA_FLAG_NO_HANDLE_INHERIT) + if err == nil { + return s, nil + } + // WSA_FLAG_NO_HANDLE_INHERIT flag is not supported on some + // old versions of Windows, see + // https://msdn.microsoft.com/en-us/library/windows/desktop/ms742212(v=vs.85).aspx + // for details. Just use syscall.Socket, if windows.WSASocket failed. + // See ../syscall/exec_unix.go for description of ForkLock. syscall.ForkLock.RLock() - s, err := socketFunc(family, sotype, proto) + s, err = socketFunc(family, sotype, proto) if err == nil { syscall.CloseOnExec(s) } diff --git a/libgo/go/net/sockoptip_bsd.go b/libgo/go/net/sockoptip_bsdvar.go index b11f3a4edbe..95601013987 100644 --- a/libgo/go/net/sockoptip_bsd.go +++ b/libgo/go/net/sockoptip_bsdvar.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin dragonfly freebsd netbsd openbsd +// +build darwin dragonfly freebsd netbsd openbsd solaris package net diff --git a/libgo/go/net/sockoptip_posix.go b/libgo/go/net/sockoptip_posix.go index 92af7646ef9..b14963ff32e 100644 --- a/libgo/go/net/sockoptip_posix.go +++ b/libgo/go/net/sockoptip_posix.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build aix darwin dragonfly freebsd linux netbsd openbsd windows +// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris windows package net diff --git a/libgo/go/net/sockoptip_stub.go b/libgo/go/net/sockoptip_stub.go index f698687514d..fc20a9fc331 100644 --- a/libgo/go/net/sockoptip_stub.go +++ b/libgo/go/net/sockoptip_stub.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build nacl solaris +// +build nacl package net diff --git a/libgo/go/net/tcpsock.go b/libgo/go/net/tcpsock.go index e957aa3005a..9528140b940 100644 --- a/libgo/go/net/tcpsock.go +++ b/libgo/go/net/tcpsock.go @@ -225,6 +225,18 @@ type TCPListener struct { fd *netFD } +// SyscallConn returns a raw network connection. +// This implements the syscall.Conn interface. +// +// The returned RawConn only supports calling Control. Read and +// Write return an error. +func (l *TCPListener) SyscallConn() (syscall.RawConn, error) { + if !l.ok() { + return nil, syscall.EINVAL + } + return newRawListener(l.fd) +} + // AcceptTCP accepts the next incoming call and returns the new // connection. func (l *TCPListener) AcceptTCP() (*TCPConn, error) { diff --git a/libgo/go/net/tcpsock_test.go b/libgo/go/net/tcpsock_test.go index 660f4249d40..04b38b60741 100644 --- a/libgo/go/net/tcpsock_test.go +++ b/libgo/go/net/tcpsock_test.go @@ -8,6 +8,7 @@ import ( "fmt" "internal/testenv" "io" + "os" "reflect" "runtime" "sync" @@ -727,3 +728,74 @@ func TestTCPBig(t *testing.T) { }) } } + +func TestCopyPipeIntoTCP(t *testing.T) { + ln, err := newLocalListener("tcp") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + errc := make(chan error, 1) + defer func() { + if err := <-errc; err != nil { + t.Error(err) + } + }() + go func() { + c, err := ln.Accept() + if err != nil { + errc <- err + return + } + defer c.Close() + + buf := make([]byte, 100) + n, err := io.ReadFull(c, buf) + if err != io.ErrUnexpectedEOF || n != 2 { + errc <- fmt.Errorf("got err=%q n=%v; want err=%q n=2", err, n, io.ErrUnexpectedEOF) + return + } + + errc <- nil + }() + + c, err := Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer c.Close() + + r, w, err := os.Pipe() + if err != nil { + t.Fatal(err) + } + defer r.Close() + + errc2 := make(chan error, 1) + defer func() { + if err := <-errc2; err != nil { + t.Error(err) + } + }() + + defer w.Close() + + go func() { + _, err := io.Copy(c, r) + errc2 <- err + }() + + // Split write into 2 packets. That makes Windows TransmitFile + // drop second packet. + packet := make([]byte, 1) + _, err = w.Write(packet) + if err != nil { + t.Fatal(err) + } + time.Sleep(100 * time.Millisecond) + _, err = w.Write(packet) + if err != nil { + t.Fatal(err) + } +} diff --git a/libgo/go/net/textproto/reader.go b/libgo/go/net/textproto/reader.go index e07d1d62e09..8c3a05264a4 100644 --- a/libgo/go/net/textproto/reader.go +++ b/libgo/go/net/textproto/reader.go @@ -476,15 +476,25 @@ func (r *Reader) ReadMIMEHeader() (MIMEHeader, error) { } m := make(MIMEHeader, hint) + + // The first line cannot start with a leading space. + if buf, err := r.R.Peek(1); err == nil && (buf[0] == ' ' || buf[0] == '\t') { + line, err := r.readLineSlice() + if err != nil { + return m, err + } + return m, ProtocolError("malformed MIME header initial line: " + string(line)) + } + for { kv, err := r.readContinuedLineSlice() if len(kv) == 0 { return m, err } - // Key ends at first colon; should not have spaces but - // they appear in the wild, violating specs, so we - // remove them if present. + // Key ends at first colon; should not have trailing spaces + // but they appear in the wild, violating specs, so we remove + // them if present. i := bytes.IndexByte(kv, ':') if i < 0 { return m, ProtocolError("malformed MIME header line: " + string(kv)) diff --git a/libgo/go/net/textproto/reader_test.go b/libgo/go/net/textproto/reader_test.go index 0c53d48b74a..c6a6ced6493 100644 --- a/libgo/go/net/textproto/reader_test.go +++ b/libgo/go/net/textproto/reader_test.go @@ -211,6 +211,24 @@ func TestReadMIMEHeaderNonCompliant(t *testing.T) { } } +func TestReadMIMEHeaderMalformed(t *testing.T) { + inputs := []string{ + "No colon first line\r\nFoo: foo\r\n\r\n", + " No colon first line with leading space\r\nFoo: foo\r\n\r\n", + "\tNo colon first line with leading tab\r\nFoo: foo\r\n\r\n", + " First: line with leading space\r\nFoo: foo\r\n\r\n", + "\tFirst: line with leading tab\r\nFoo: foo\r\n\r\n", + "Foo: foo\r\nNo colon second line\r\n\r\n", + } + + for _, input := range inputs { + r := reader(input) + if m, err := r.ReadMIMEHeader(); err == nil { + t.Errorf("ReadMIMEHeader(%q) = %v, %v; want nil, err", input, m, err) + } + } +} + // Test that continued lines are properly trimmed. Issue 11204. func TestReadMIMEHeaderTrimContinued(t *testing.T) { // In this header, \n and \r\n terminated lines are mixed on purpose. diff --git a/libgo/go/net/udpsock.go b/libgo/go/net/udpsock.go index 2c0f74fdabd..158265f06f8 100644 --- a/libgo/go/net/udpsock.go +++ b/libgo/go/net/udpsock.go @@ -9,7 +9,7 @@ import ( "syscall" ) -// BUG(mikio): On NaCl, Plan 9 and Windows, the ReadMsgUDP and +// BUG(mikio): On NaCl and Plan 9, the ReadMsgUDP and // WriteMsgUDP methods of UDPConn are not implemented. // BUG(mikio): On Windows, the File method of UDPConn is not diff --git a/libgo/go/net/udpsock_test.go b/libgo/go/net/udpsock_test.go index 6d4974e3e49..4ae014c01d9 100644 --- a/libgo/go/net/udpsock_test.go +++ b/libgo/go/net/udpsock_test.go @@ -161,7 +161,7 @@ func testWriteToConn(t *testing.T, raddr string) { } _, _, err = c.(*UDPConn).WriteMsgUDP(b, nil, nil) switch runtime.GOOS { - case "nacl", "windows": // see golang.org/issue/9252 + case "nacl": // see golang.org/issue/9252 t.Skipf("not implemented yet on %s", runtime.GOOS) default: if err != nil { @@ -204,7 +204,7 @@ func testWriteToPacketConn(t *testing.T, raddr string) { } _, _, err = c.(*UDPConn).WriteMsgUDP(b, nil, ra) switch runtime.GOOS { - case "nacl", "windows": // see golang.org/issue/9252 + case "nacl": // see golang.org/issue/9252 t.Skipf("not implemented yet on %s", runtime.GOOS) default: if err != nil { diff --git a/libgo/go/net/unixsock.go b/libgo/go/net/unixsock.go index 057940acf65..20326dabeaa 100644 --- a/libgo/go/net/unixsock.go +++ b/libgo/go/net/unixsock.go @@ -219,6 +219,18 @@ type UnixListener struct { func (ln *UnixListener) ok() bool { return ln != nil && ln.fd != nil } +// SyscallConn returns a raw network connection. +// This implements the syscall.Conn interface. +// +// The returned RawConn only supports calling Control. Read and +// Write return an error. +func (l *UnixListener) SyscallConn() (syscall.RawConn, error) { + if !l.ok() { + return nil, syscall.EINVAL + } + return newRawListener(l.fd) +} + // AcceptUnix accepts the next incoming call and returns the new // connection. func (l *UnixListener) AcceptUnix() (*UnixConn, error) { diff --git a/libgo/go/net/unixsock_linux_test.go b/libgo/go/net/unixsock_linux_test.go new file mode 100644 index 00000000000..d04007cef38 --- /dev/null +++ b/libgo/go/net/unixsock_linux_test.go @@ -0,0 +1,104 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "bytes" + "reflect" + "syscall" + "testing" + "time" +) + +func TestUnixgramAutobind(t *testing.T) { + laddr := &UnixAddr{Name: "", Net: "unixgram"} + c1, err := ListenUnixgram("unixgram", laddr) + if err != nil { + t.Fatal(err) + } + defer c1.Close() + + // retrieve the autobind address + autoAddr := c1.LocalAddr().(*UnixAddr) + if len(autoAddr.Name) <= 1 { + t.Fatalf("invalid autobind address: %v", autoAddr) + } + if autoAddr.Name[0] != '@' { + t.Fatalf("invalid autobind address: %v", autoAddr) + } + + c2, err := DialUnix("unixgram", nil, autoAddr) + if err != nil { + t.Fatal(err) + } + defer c2.Close() + + if !reflect.DeepEqual(c1.LocalAddr(), c2.RemoteAddr()) { + t.Fatalf("expected autobind address %v, got %v", c1.LocalAddr(), c2.RemoteAddr()) + } +} + +func TestUnixAutobindClose(t *testing.T) { + laddr := &UnixAddr{Name: "", Net: "unix"} + ln, err := ListenUnix("unix", laddr) + if err != nil { + t.Fatal(err) + } + ln.Close() +} + +func TestUnixgramLinuxAbstractLongName(t *testing.T) { + if !testableNetwork("unixgram") { + t.Skip("abstract unix socket long name test") + } + + // Create an abstract socket name whose length is exactly + // the maximum RawSockkaddrUnix Path len + rsu := syscall.RawSockaddrUnix{} + addrBytes := make([]byte, len(rsu.Path)) + copy(addrBytes, "@abstract_test") + addr := string(addrBytes) + + la, err := ResolveUnixAddr("unixgram", addr) + if err != nil { + t.Fatal(err) + } + c, err := ListenUnixgram("unixgram", la) + if err != nil { + t.Fatal(err) + } + defer c.Close() + + off := make(chan bool) + data := [5]byte{1, 2, 3, 4, 5} + go func() { + defer func() { off <- true }() + s, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_DGRAM, 0) + if err != nil { + t.Error(err) + return + } + defer syscall.Close(s) + rsa := &syscall.SockaddrUnix{Name: addr} + if err := syscall.Sendto(s, data[:], 0, rsa); err != nil { + t.Error(err) + return + } + }() + + <-off + b := make([]byte, 64) + c.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) + n, from, err := c.ReadFrom(b) + if err != nil { + t.Fatal(err) + } + if from != nil { + t.Fatalf("unexpected peer address: %v", from) + } + if !bytes.Equal(b[:n], data[:]) { + t.Fatalf("got %v; want %v", b[:n], data[:]) + } +} diff --git a/libgo/go/net/unixsock_test.go b/libgo/go/net/unixsock_test.go index 489a29bc7d7..3e5c8bc3769 100644 --- a/libgo/go/net/unixsock_test.go +++ b/libgo/go/net/unixsock_test.go @@ -170,51 +170,6 @@ func TestUnixgramZeroByteBuffer(t *testing.T) { } } -func TestUnixgramAutobind(t *testing.T) { - if runtime.GOOS != "linux" { - t.Skip("autobind is linux only") - } - - laddr := &UnixAddr{Name: "", Net: "unixgram"} - c1, err := ListenUnixgram("unixgram", laddr) - if err != nil { - t.Fatal(err) - } - defer c1.Close() - - // retrieve the autobind address - autoAddr := c1.LocalAddr().(*UnixAddr) - if len(autoAddr.Name) <= 1 { - t.Fatalf("invalid autobind address: %v", autoAddr) - } - if autoAddr.Name[0] != '@' { - t.Fatalf("invalid autobind address: %v", autoAddr) - } - - c2, err := DialUnix("unixgram", nil, autoAddr) - if err != nil { - t.Fatal(err) - } - defer c2.Close() - - if !reflect.DeepEqual(c1.LocalAddr(), c2.RemoteAddr()) { - t.Fatalf("expected autobind address %v, got %v", c1.LocalAddr(), c2.RemoteAddr()) - } -} - -func TestUnixAutobindClose(t *testing.T) { - if runtime.GOOS != "linux" { - t.Skip("autobind is linux only") - } - - laddr := &UnixAddr{Name: "", Net: "unix"} - ln, err := ListenUnix("unix", laddr) - if err != nil { - t.Fatal(err) - } - ln.Close() -} - func TestUnixgramWrite(t *testing.T) { if !testableNetwork("unixgram") { t.Skip("unixgram test") diff --git a/libgo/go/net/url/url.go b/libgo/go/net/url/url.go index 2ac24725692..7c3d24493e0 100644 --- a/libgo/go/net/url/url.go +++ b/libgo/go/net/url/url.go @@ -163,18 +163,23 @@ func shouldEscape(c byte, mode encoding) bool { return true } -// QueryUnescape does the inverse transformation of QueryEscape, converting -// %AB into the byte 0xAB and '+' into ' ' (space). It returns an error if -// any % is not followed by two hexadecimal digits. +// QueryUnescape does the inverse transformation of QueryEscape, +// converting each 3-byte encoded substring of the form "%AB" into the +// hex-decoded byte 0xAB. It also converts '+' into ' ' (space). +// It returns an error if any % is not followed by two hexadecimal +// digits. func QueryUnescape(s string) (string, error) { return unescape(s, encodeQueryComponent) } -// PathUnescape does the inverse transformation of PathEscape, converting -// %AB into the byte 0xAB. It returns an error if any % is not followed by -// two hexadecimal digits. +// PathUnescape does the inverse transformation of PathEscape, +// converting each 3-byte encoded substring of the form "%AB" into the +// hex-decoded byte 0xAB. It also converts '+' into ' ' (space). +// It returns an error if any % is not followed by two hexadecimal +// digits. // -// PathUnescape is identical to QueryUnescape except that it does not unescape '+' to ' ' (space). +// PathUnescape is identical to QueryUnescape except that it does not +// unescape '+' to ' ' (space). func PathUnescape(s string) (string, error) { return unescape(s, encodePathSegment) } @@ -367,17 +372,26 @@ type Userinfo struct { // Username returns the username. func (u *Userinfo) Username() string { + if u == nil { + return "" + } return u.username } // Password returns the password in case it is set, and whether it is set. func (u *Userinfo) Password() (string, bool) { + if u == nil { + return "", false + } return u.password, u.passwordSet } // String returns the encoded userinfo information in the standard form // of "username[:password]". func (u *Userinfo) String() string { + if u == nil { + return "" + } s := escape(u.username, encodeUserPassword) if u.passwordSet { s += ":" + escape(u.password, encodeUserPassword) @@ -427,7 +441,11 @@ func split(s string, c string, cutc bool) (string, string) { } // Parse parses rawurl into a URL structure. -// The rawurl may be relative or absolute. +// +// The rawurl may be relative (a path, without a host) or absolute +// (starting with a scheme). Trying to parse a hostname and path +// without a scheme is invalid but may not necessarily return an +// error, due to parsing ambiguities. func Parse(rawurl string) (*URL, error) { // Cut off #frag u, frag := split(rawurl, "#", true) @@ -726,7 +744,9 @@ func (u *URL) String() string { buf.WriteString(u.Opaque) } else { if u.Scheme != "" || u.Host != "" || u.User != nil { - buf.WriteString("//") + if u.Host != "" || u.Path != "" || u.User != nil { + buf.WriteString("//") + } if ui := u.User; ui != nil { buf.WriteString(ui.String()) buf.WriteByte('@') @@ -909,7 +929,7 @@ func resolvePath(base, ref string) string { // Add final slash to the joined path. dst = append(dst, "") } - return "/" + strings.TrimLeft(strings.Join(dst, "/"), "/") + return "/" + strings.TrimPrefix(strings.Join(dst, "/"), "/") } // IsAbs reports whether the URL is absolute. @@ -953,12 +973,10 @@ func (u *URL) ResolveReference(ref *URL) *URL { url.Path = "" return &url } - if ref.Path == "" { - if ref.RawQuery == "" { - url.RawQuery = u.RawQuery - if ref.Fragment == "" { - url.Fragment = u.Fragment - } + if ref.Path == "" && ref.RawQuery == "" { + url.RawQuery = u.RawQuery + if ref.Fragment == "" { + url.Fragment = u.Fragment } } // The "abs_path" or "rel_path" cases. diff --git a/libgo/go/net/url/url_test.go b/libgo/go/net/url/url_test.go index 6c3bb21d20c..d6aed3acafa 100644 --- a/libgo/go/net/url/url_test.go +++ b/libgo/go/net/url/url_test.go @@ -568,6 +568,28 @@ var urltests = []URLTest{ }, "", }, + // test we can roundtrip magnet url + // fix issue https://golang.org/issue/20054 + { + "magnet:?xt=urn:btih:c12fe1c06bba254a9dc9f519b335aa7c1367a88a&dn", + &URL{ + Scheme: "magnet", + Host: "", + Path: "", + RawQuery: "xt=urn:btih:c12fe1c06bba254a9dc9f519b335aa7c1367a88a&dn", + }, + "magnet:?xt=urn:btih:c12fe1c06bba254a9dc9f519b335aa7c1367a88a&dn", + }, + { + "mailto:?subject=hi", + &URL{ + Scheme: "mailto", + Host: "", + Path: "", + RawQuery: "subject=hi", + }, + "mailto:?subject=hi", + }, } // more useful string for debugging than fmt's struct printer @@ -1010,6 +1032,10 @@ var resolveReferenceTests = []struct { {"http://foo.com/bar?a=b", "/baz?", "http://foo.com/baz?"}, {"http://foo.com/bar?a=b", "/baz?c=d", "http://foo.com/baz?c=d"}, + // Multiple slashes + {"http://foo.com/bar", "http://foo.com//baz", "http://foo.com//baz"}, + {"http://foo.com/bar", "http://foo.com///baz/quux", "http://foo.com///baz/quux"}, + // Scheme-relative {"https://foo.com/bar?a=b", "//bar.com/quux", "https://bar.com/quux"}, @@ -1683,3 +1709,29 @@ func TestGob(t *testing.T) { t.Errorf("json decoded to: %s\nwant: %s\n", u1, u) } } + +func TestNilUser(t *testing.T) { + defer func() { + if v := recover(); v != nil { + t.Fatalf("unexpected panic: %v", v) + } + }() + + u, err := Parse("http://foo.com/") + + if err != nil { + t.Fatalf("parse err: %v", err) + } + + if v := u.User.Username(); v != "" { + t.Fatalf("expected empty username, got %s", v) + } + + if v, ok := u.User.Password(); v != "" || ok { + t.Fatalf("expected empty password, got %s (%v)", v, ok) + } + + if v := u.User.String(); v != "" { + t.Fatalf("expected empty string, got %s", v) + } +} diff --git a/libgo/go/net/write_unix_test.go b/libgo/go/net/write_unix_test.go new file mode 100644 index 00000000000..6d8cb6a6f81 --- /dev/null +++ b/libgo/go/net/write_unix_test.go @@ -0,0 +1,66 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin dragonfly freebsd linux netbsd openbsd solaris + +package net + +import ( + "bytes" + "syscall" + "testing" + "time" +) + +// Test that a client can't trigger an endless loop of write system +// calls on the server by shutting down the write side on the client. +// Possibility raised in the discussion of https://golang.org/cl/71973. +func TestEndlessWrite(t *testing.T) { + t.Parallel() + c := make(chan bool) + server := func(cs *TCPConn) error { + cs.CloseWrite() + <-c + return nil + } + client := func(ss *TCPConn) error { + // Tell the server to return when we return. + defer close(c) + + // Loop writing to the server. The server is not reading + // anything, so this will eventually block, and then time out. + b := bytes.Repeat([]byte{'a'}, 8192) + cagain := 0 + for { + n, err := ss.conn.fd.pfd.WriteOnce(b) + if n > 0 { + cagain = 0 + } + switch err { + case nil: + case syscall.EAGAIN: + if cagain == 0 { + // We've written enough data to + // start blocking. Set a deadline + // so that we will stop. + ss.SetWriteDeadline(time.Now().Add(5 * time.Millisecond)) + } + cagain++ + if cagain > 20 { + t.Error("looping on EAGAIN") + return nil + } + if err = ss.conn.fd.pfd.WaitWrite(); err != nil { + t.Logf("client WaitWrite: %v", err) + return nil + } + default: + // We expect to eventually get an error. + t.Logf("client WriteOnce: %v", err) + return nil + } + } + } + withTCPConnPair(t, client, server) +} |