summaryrefslogtreecommitdiff
path: root/libgo/go/net/dnsclient_unix.go
diff options
context:
space:
mode:
Diffstat (limited to 'libgo/go/net/dnsclient_unix.go')
-rw-r--r--libgo/go/net/dnsclient_unix.go101
1 files changed, 64 insertions, 37 deletions
diff --git a/libgo/go/net/dnsclient_unix.go b/libgo/go/net/dnsclient_unix.go
index 94dbe95afa3..b3284b8cd76 100644
--- a/libgo/go/net/dnsclient_unix.go
+++ b/libgo/go/net/dnsclient_unix.go
@@ -23,7 +23,13 @@ import (
"sync"
"time"
- "internal/x/net/dns/dnsmessage"
+ "golang.org/x/net/dns/dnsmessage"
+)
+
+const (
+ // to be used as a useTCP parameter to exchange
+ useTCPOnly = true
+ useUDPOrTCP = false
)
var (
@@ -131,13 +137,19 @@ func dnsStreamRoundTrip(c Conn, id uint16, query dnsmessage.Question, b []byte)
}
// exchange sends a query on the connection and hopes for a response.
-func (r *Resolver) exchange(ctx context.Context, server string, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) {
+func (r *Resolver) exchange(ctx context.Context, server string, q dnsmessage.Question, timeout time.Duration, useTCP bool) (dnsmessage.Parser, dnsmessage.Header, error) {
q.Class = dnsmessage.ClassINET
id, udpReq, tcpReq, err := newRequest(q)
if err != nil {
return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotMarshalDNSMessage
}
- for _, network := range []string{"udp", "tcp"} {
+ var networks []string
+ if useTCP {
+ networks = []string{"tcp"}
+ } else {
+ networks = []string{"udp", "tcp"}
+ }
+ for _, network := range networks {
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout))
defer cancel()
@@ -171,7 +183,7 @@ func (r *Resolver) exchange(ctx context.Context, server string, q dnsmessage.Que
}
// checkHeader performs basic sanity checks on the header.
-func checkHeader(p *dnsmessage.Parser, h dnsmessage.Header, name, server string) error {
+func checkHeader(p *dnsmessage.Parser, h dnsmessage.Header) error {
if h.RCode == dnsmessage.RCodeNameError {
return errNoSuchHost
}
@@ -202,7 +214,7 @@ func checkHeader(p *dnsmessage.Parser, h dnsmessage.Header, name, server string)
return nil
}
-func skipToAnswer(p *dnsmessage.Parser, qtype dnsmessage.Type, name, server string) error {
+func skipToAnswer(p *dnsmessage.Parser, qtype dnsmessage.Type) error {
for {
h, err := p.AnswerHeader()
if err == dnsmessage.ErrSectionDone {
@@ -241,7 +253,7 @@ func (r *Resolver) tryOneName(ctx context.Context, cfg *dnsConfig, name string,
for j := uint32(0); j < sLen; j++ {
server := cfg.servers[(serverOffset+j)%sLen]
- p, h, err := r.exchange(ctx, server, q, cfg.timeout)
+ p, h, err := r.exchange(ctx, server, q, cfg.timeout, cfg.useTCP)
if err != nil {
dnsErr := &DNSError{
Err: err.Error(),
@@ -260,7 +272,7 @@ func (r *Resolver) tryOneName(ctx context.Context, cfg *dnsConfig, name string,
continue
}
- if err := checkHeader(&p, h, name, server); err != nil {
+ if err := checkHeader(&p, h); err != nil {
dnsErr := &DNSError{
Err: err.Error(),
Name: name,
@@ -272,17 +284,15 @@ func (r *Resolver) tryOneName(ctx context.Context, cfg *dnsConfig, name string,
if err == errNoSuchHost {
// The name does not exist, so trying
// another server won't help.
- //
- // TODO: indicate this in a more
- // obvious way, such as a field on
- // DNSError?
+
+ dnsErr.IsNotFound = true
return p, server, dnsErr
}
lastErr = dnsErr
continue
}
- err = skipToAnswer(&p, qtype, name, server)
+ err = skipToAnswer(&p, qtype)
if err == nil {
return p, server, nil
}
@@ -294,9 +304,8 @@ func (r *Resolver) tryOneName(ctx context.Context, cfg *dnsConfig, name string,
if err == errNoSuchHost {
// The name does not exist, so trying another
// server won't help.
- //
- // TODO: indicate this in a more obvious way,
- // such as a field on DNSError?
+
+ lastErr.(*DNSError).IsNotFound = true
return p, server, lastErr
}
}
@@ -386,7 +395,7 @@ func (r *Resolver) lookup(ctx context.Context, name string, qtype dnsmessage.Typ
// Other lookups might allow broader name syntax
// (for example Multicast DNS allows UTF-8; see RFC 6762).
// For consistency with libc resolvers, report no such host.
- return dnsmessage.Parser{}, "", &DNSError{Err: errNoSuchHost.Error(), Name: name}
+ return dnsmessage.Parser{}, "", &DNSError{Err: errNoSuchHost.Error(), Name: name, IsNotFound: true}
}
resolvConf.tryUpdate("/etc/resolv.conf")
resolvConf.mu.RLock()
@@ -563,40 +572,58 @@ func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, name string, order
}
if !isDomainName(name) {
// See comment in func lookup above about use of errNoSuchHost.
- return nil, dnsmessage.Name{}, &DNSError{Err: errNoSuchHost.Error(), Name: name}
+ return nil, dnsmessage.Name{}, &DNSError{Err: errNoSuchHost.Error(), Name: name, IsNotFound: true}
}
resolvConf.tryUpdate("/etc/resolv.conf")
resolvConf.mu.RLock()
conf := resolvConf.dnsConfig
resolvConf.mu.RUnlock()
- type racer struct {
+ type result struct {
p dnsmessage.Parser
server string
error
}
- lane := make(chan racer, 1)
+ lane := make(chan result, 1)
qtypes := [...]dnsmessage.Type{dnsmessage.TypeA, dnsmessage.TypeAAAA}
- var lastErr error
- for _, fqdn := range conf.nameList(name) {
- for _, qtype := range qtypes {
+ var queryFn func(fqdn string, qtype dnsmessage.Type)
+ var responseFn func(fqdn string, qtype dnsmessage.Type) result
+ if conf.singleRequest {
+ queryFn = func(fqdn string, qtype dnsmessage.Type) {}
+ responseFn = func(fqdn string, qtype dnsmessage.Type) result {
+ dnsWaitGroup.Add(1)
+ defer dnsWaitGroup.Done()
+ p, server, err := r.tryOneName(ctx, conf, fqdn, qtype)
+ return result{p, server, err}
+ }
+ } else {
+ queryFn = func(fqdn string, qtype dnsmessage.Type) {
dnsWaitGroup.Add(1)
go func(qtype dnsmessage.Type) {
p, server, err := r.tryOneName(ctx, conf, fqdn, qtype)
- lane <- racer{p, server, err}
+ lane <- result{p, server, err}
dnsWaitGroup.Done()
}(qtype)
}
+ responseFn = func(fqdn string, qtype dnsmessage.Type) result {
+ return <-lane
+ }
+ }
+ var lastErr error
+ for _, fqdn := range conf.nameList(name) {
+ for _, qtype := range qtypes {
+ queryFn(fqdn, qtype)
+ }
hitStrictError := false
- for range qtypes {
- racer := <-lane
- if racer.error != nil {
- if nerr, ok := racer.error.(Error); ok && nerr.Temporary() && r.strictErrors() {
+ for _, qtype := range qtypes {
+ result := responseFn(fqdn, qtype)
+ if result.error != nil {
+ if nerr, ok := result.error.(Error); ok && nerr.Temporary() && r.strictErrors() {
// This error will abort the nameList loop.
hitStrictError = true
- lastErr = racer.error
+ lastErr = result.error
} else if lastErr == nil || fqdn == name+"." {
// Prefer error for original name.
- lastErr = racer.error
+ lastErr = result.error
}
continue
}
@@ -618,12 +645,12 @@ func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, name string, order
loop:
for {
- h, err := racer.p.AnswerHeader()
+ h, err := result.p.AnswerHeader()
if err != nil && err != dnsmessage.ErrSectionDone {
lastErr = &DNSError{
Err: "cannot marshal DNS message",
Name: name,
- Server: racer.server,
+ Server: result.server,
}
}
if err != nil {
@@ -631,35 +658,35 @@ func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, name string, order
}
switch h.Type {
case dnsmessage.TypeA:
- a, err := racer.p.AResource()
+ a, err := result.p.AResource()
if err != nil {
lastErr = &DNSError{
Err: "cannot marshal DNS message",
Name: name,
- Server: racer.server,
+ Server: result.server,
}
break loop
}
addrs = append(addrs, IPAddr{IP: IP(a.A[:])})
case dnsmessage.TypeAAAA:
- aaaa, err := racer.p.AAAAResource()
+ aaaa, err := result.p.AAAAResource()
if err != nil {
lastErr = &DNSError{
Err: "cannot marshal DNS message",
Name: name,
- Server: racer.server,
+ Server: result.server,
}
break loop
}
addrs = append(addrs, IPAddr{IP: IP(aaaa.AAAA[:])})
default:
- if err := racer.p.SkipAnswer(); err != nil {
+ if err := result.p.SkipAnswer(); err != nil {
lastErr = &DNSError{
Err: "cannot marshal DNS message",
Name: name,
- Server: racer.server,
+ Server: result.server,
}
break loop
}