diff options
Diffstat (limited to 'libgo/go/net/dnsclient_unix.go')
-rw-r--r-- | libgo/go/net/dnsclient_unix.go | 101 |
1 files changed, 64 insertions, 37 deletions
diff --git a/libgo/go/net/dnsclient_unix.go b/libgo/go/net/dnsclient_unix.go index 94dbe95afa3..b3284b8cd76 100644 --- a/libgo/go/net/dnsclient_unix.go +++ b/libgo/go/net/dnsclient_unix.go @@ -23,7 +23,13 @@ import ( "sync" "time" - "internal/x/net/dns/dnsmessage" + "golang.org/x/net/dns/dnsmessage" +) + +const ( + // to be used as a useTCP parameter to exchange + useTCPOnly = true + useUDPOrTCP = false ) var ( @@ -131,13 +137,19 @@ func dnsStreamRoundTrip(c Conn, id uint16, query dnsmessage.Question, b []byte) } // exchange sends a query on the connection and hopes for a response. -func (r *Resolver) exchange(ctx context.Context, server string, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) { +func (r *Resolver) exchange(ctx context.Context, server string, q dnsmessage.Question, timeout time.Duration, useTCP bool) (dnsmessage.Parser, dnsmessage.Header, error) { q.Class = dnsmessage.ClassINET id, udpReq, tcpReq, err := newRequest(q) if err != nil { return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotMarshalDNSMessage } - for _, network := range []string{"udp", "tcp"} { + var networks []string + if useTCP { + networks = []string{"tcp"} + } else { + networks = []string{"udp", "tcp"} + } + for _, network := range networks { ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout)) defer cancel() @@ -171,7 +183,7 @@ func (r *Resolver) exchange(ctx context.Context, server string, q dnsmessage.Que } // checkHeader performs basic sanity checks on the header. -func checkHeader(p *dnsmessage.Parser, h dnsmessage.Header, name, server string) error { +func checkHeader(p *dnsmessage.Parser, h dnsmessage.Header) error { if h.RCode == dnsmessage.RCodeNameError { return errNoSuchHost } @@ -202,7 +214,7 @@ func checkHeader(p *dnsmessage.Parser, h dnsmessage.Header, name, server string) return nil } -func skipToAnswer(p *dnsmessage.Parser, qtype dnsmessage.Type, name, server string) error { +func skipToAnswer(p *dnsmessage.Parser, qtype dnsmessage.Type) error { for { h, err := p.AnswerHeader() if err == dnsmessage.ErrSectionDone { @@ -241,7 +253,7 @@ func (r *Resolver) tryOneName(ctx context.Context, cfg *dnsConfig, name string, for j := uint32(0); j < sLen; j++ { server := cfg.servers[(serverOffset+j)%sLen] - p, h, err := r.exchange(ctx, server, q, cfg.timeout) + p, h, err := r.exchange(ctx, server, q, cfg.timeout, cfg.useTCP) if err != nil { dnsErr := &DNSError{ Err: err.Error(), @@ -260,7 +272,7 @@ func (r *Resolver) tryOneName(ctx context.Context, cfg *dnsConfig, name string, continue } - if err := checkHeader(&p, h, name, server); err != nil { + if err := checkHeader(&p, h); err != nil { dnsErr := &DNSError{ Err: err.Error(), Name: name, @@ -272,17 +284,15 @@ func (r *Resolver) tryOneName(ctx context.Context, cfg *dnsConfig, name string, if err == errNoSuchHost { // The name does not exist, so trying // another server won't help. - // - // TODO: indicate this in a more - // obvious way, such as a field on - // DNSError? + + dnsErr.IsNotFound = true return p, server, dnsErr } lastErr = dnsErr continue } - err = skipToAnswer(&p, qtype, name, server) + err = skipToAnswer(&p, qtype) if err == nil { return p, server, nil } @@ -294,9 +304,8 @@ func (r *Resolver) tryOneName(ctx context.Context, cfg *dnsConfig, name string, if err == errNoSuchHost { // The name does not exist, so trying another // server won't help. - // - // TODO: indicate this in a more obvious way, - // such as a field on DNSError? + + lastErr.(*DNSError).IsNotFound = true return p, server, lastErr } } @@ -386,7 +395,7 @@ func (r *Resolver) lookup(ctx context.Context, name string, qtype dnsmessage.Typ // Other lookups might allow broader name syntax // (for example Multicast DNS allows UTF-8; see RFC 6762). // For consistency with libc resolvers, report no such host. - return dnsmessage.Parser{}, "", &DNSError{Err: errNoSuchHost.Error(), Name: name} + return dnsmessage.Parser{}, "", &DNSError{Err: errNoSuchHost.Error(), Name: name, IsNotFound: true} } resolvConf.tryUpdate("/etc/resolv.conf") resolvConf.mu.RLock() @@ -563,40 +572,58 @@ func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, name string, order } if !isDomainName(name) { // See comment in func lookup above about use of errNoSuchHost. - return nil, dnsmessage.Name{}, &DNSError{Err: errNoSuchHost.Error(), Name: name} + return nil, dnsmessage.Name{}, &DNSError{Err: errNoSuchHost.Error(), Name: name, IsNotFound: true} } resolvConf.tryUpdate("/etc/resolv.conf") resolvConf.mu.RLock() conf := resolvConf.dnsConfig resolvConf.mu.RUnlock() - type racer struct { + type result struct { p dnsmessage.Parser server string error } - lane := make(chan racer, 1) + lane := make(chan result, 1) qtypes := [...]dnsmessage.Type{dnsmessage.TypeA, dnsmessage.TypeAAAA} - var lastErr error - for _, fqdn := range conf.nameList(name) { - for _, qtype := range qtypes { + var queryFn func(fqdn string, qtype dnsmessage.Type) + var responseFn func(fqdn string, qtype dnsmessage.Type) result + if conf.singleRequest { + queryFn = func(fqdn string, qtype dnsmessage.Type) {} + responseFn = func(fqdn string, qtype dnsmessage.Type) result { + dnsWaitGroup.Add(1) + defer dnsWaitGroup.Done() + p, server, err := r.tryOneName(ctx, conf, fqdn, qtype) + return result{p, server, err} + } + } else { + queryFn = func(fqdn string, qtype dnsmessage.Type) { dnsWaitGroup.Add(1) go func(qtype dnsmessage.Type) { p, server, err := r.tryOneName(ctx, conf, fqdn, qtype) - lane <- racer{p, server, err} + lane <- result{p, server, err} dnsWaitGroup.Done() }(qtype) } + responseFn = func(fqdn string, qtype dnsmessage.Type) result { + return <-lane + } + } + var lastErr error + for _, fqdn := range conf.nameList(name) { + for _, qtype := range qtypes { + queryFn(fqdn, qtype) + } hitStrictError := false - for range qtypes { - racer := <-lane - if racer.error != nil { - if nerr, ok := racer.error.(Error); ok && nerr.Temporary() && r.strictErrors() { + for _, qtype := range qtypes { + result := responseFn(fqdn, qtype) + if result.error != nil { + if nerr, ok := result.error.(Error); ok && nerr.Temporary() && r.strictErrors() { // This error will abort the nameList loop. hitStrictError = true - lastErr = racer.error + lastErr = result.error } else if lastErr == nil || fqdn == name+"." { // Prefer error for original name. - lastErr = racer.error + lastErr = result.error } continue } @@ -618,12 +645,12 @@ func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, name string, order loop: for { - h, err := racer.p.AnswerHeader() + h, err := result.p.AnswerHeader() if err != nil && err != dnsmessage.ErrSectionDone { lastErr = &DNSError{ Err: "cannot marshal DNS message", Name: name, - Server: racer.server, + Server: result.server, } } if err != nil { @@ -631,35 +658,35 @@ func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, name string, order } switch h.Type { case dnsmessage.TypeA: - a, err := racer.p.AResource() + a, err := result.p.AResource() if err != nil { lastErr = &DNSError{ Err: "cannot marshal DNS message", Name: name, - Server: racer.server, + Server: result.server, } break loop } addrs = append(addrs, IPAddr{IP: IP(a.A[:])}) case dnsmessage.TypeAAAA: - aaaa, err := racer.p.AAAAResource() + aaaa, err := result.p.AAAAResource() if err != nil { lastErr = &DNSError{ Err: "cannot marshal DNS message", Name: name, - Server: racer.server, + Server: result.server, } break loop } addrs = append(addrs, IPAddr{IP: IP(aaaa.AAAA[:])}) default: - if err := racer.p.SkipAnswer(); err != nil { + if err := result.p.SkipAnswer(); err != nil { lastErr = &DNSError{ Err: "cannot marshal DNS message", Name: name, - Server: racer.server, + Server: result.server, } break loop } |