util/dnsname: add FQDN type, use throughout codebase.

Signed-off-by: David Anderson <danderson@tailscale.com>
This commit is contained in:
David Anderson
2021-04-09 15:24:47 -07:00
committed by Dave Anderson
parent 7a1813fd24
commit 1a371b93be
13 changed files with 449 additions and 214 deletions

View File

@@ -9,7 +9,6 @@ package resolver
import (
"encoding/hex"
"errors"
"fmt"
"sort"
"strings"
"sync"
@@ -59,12 +58,12 @@ type Config struct {
// queries within that suffix.
// Queries only match the most specific suffix.
// To register a "default route", add an entry for ".".
Routes map[string][]netaddr.IPPort
Routes map[dnsname.FQDN][]netaddr.IPPort
// LocalHosts is a map of FQDNs to corresponding IPs.
Hosts map[string][]netaddr.IP
Hosts map[dnsname.FQDN][]netaddr.IP
// LocalDomains is a list of DNS name suffixes that should not be
// routed to upstream resolvers.
LocalDomains []string
LocalDomains []dnsname.FQDN
}
// Resolver is a DNS resolver for nodes on the Tailscale network,
@@ -92,9 +91,9 @@ type Resolver struct {
// mu guards the following fields from being updated while used.
mu sync.Mutex
localDomains []string
hostToIP map[string][]netaddr.IP
ipToHost map[netaddr.IP]string
localDomains []dnsname.FQDN
hostToIP map[dnsname.FQDN][]netaddr.IP
ipToHost map[netaddr.IP]dnsname.FQDN
}
// New returns a new resolver.
@@ -107,8 +106,8 @@ func New(logf logger.Logf, linkMon *monitor.Mon) *Resolver {
responses: make(chan packet),
errors: make(chan error),
closed: make(chan struct{}),
hostToIP: map[string][]netaddr.IP{},
ipToHost: map[netaddr.IP]string{},
hostToIP: map[dnsname.FQDN][]netaddr.IP{},
ipToHost: map[netaddr.IP]dnsname.FQDN{},
}
r.forwarder = newForwarder(r.logf, r.responses)
if r.linkMon != nil {
@@ -121,10 +120,6 @@ func New(logf logger.Logf, linkMon *monitor.Mon) *Resolver {
return r
}
func isFQDN(s string) bool {
return strings.HasSuffix(s, ".")
}
func (r *Resolver) TestOnlySetHook(hook func(Config)) { r.saveConfigForTests = hook }
func (r *Resolver) SetConfig(cfg Config) error {
@@ -133,26 +128,15 @@ func (r *Resolver) SetConfig(cfg Config) error {
}
routes := make([]route, 0, len(cfg.Routes))
reverse := make(map[netaddr.IP]string, len(cfg.Hosts))
reverse := make(map[netaddr.IP]dnsname.FQDN, len(cfg.Hosts))
for host, ips := range cfg.Hosts {
if !isFQDN(host) {
return fmt.Errorf("host entry %q is not a FQDN", host)
}
for _, ip := range ips {
reverse[ip] = host
}
}
for _, domain := range cfg.LocalDomains {
if !isFQDN(domain) {
return fmt.Errorf("local domain %q is not a FQDN", domain)
}
}
for suffix, ips := range cfg.Routes {
if !strings.HasSuffix(suffix, ".") {
return fmt.Errorf("route suffix %q is not a FQDN", suffix)
}
routes = append(routes, route{
suffix: suffix,
resolvers: ips,
@@ -160,7 +144,7 @@ func (r *Resolver) SetConfig(cfg Config) error {
}
// Sort from longest prefix to shortest.
sort.Slice(routes, func(i, j int) bool {
return dnsname.NumLabels(routes[i].suffix) > dnsname.NumLabels(routes[j].suffix)
return routes[i].suffix.NumLabels() > routes[j].suffix.NumLabels()
})
r.forwarder.setRoutes(routes)
@@ -229,12 +213,11 @@ func (r *Resolver) NextResponse() (packet []byte, to netaddr.IPPort, err error)
// resolveLocal returns an IP for the given domain, if domain is in
// the local hosts map and has an IP corresponding to the requested
// typ (A, AAAA, ALL).
// The domain name must be in canonical form (with a trailing period).
// Returns dns.RCodeRefused to indicate that the local map is not
// authoritative for domain.
func (r *Resolver) resolveLocal(domain string, typ dns.Type) (netaddr.IP, dns.RCode) {
func (r *Resolver) resolveLocal(domain dnsname.FQDN, typ dns.Type) (netaddr.IP, dns.RCode) {
// Reject .onion domains per RFC 7686.
if dnsname.HasSuffix(domain, ".onion") {
if dnsname.HasSuffix(domain.WithoutTrailingDot(), ".onion") {
return netaddr.IP{}, dns.RCodeNameError
}
@@ -246,7 +229,7 @@ func (r *Resolver) resolveLocal(domain string, typ dns.Type) (netaddr.IP, dns.RC
addrs, found := hosts[domain]
if !found {
for _, suffix := range localDomains {
if dnsname.HasSuffix(domain, suffix) {
if suffix.Contains(domain) {
// We are authoritative for the queried domain.
return netaddr.IP{}, dns.RCodeNameError
}
@@ -304,8 +287,7 @@ func (r *Resolver) resolveLocal(domain string, typ dns.Type) (netaddr.IP, dns.RC
}
// resolveReverse returns the unique domain name that maps to the given address.
// The returned domain name is in canonical form (with a trailing period).
func (r *Resolver) resolveLocalReverse(ip netaddr.IP) (string, dns.RCode) {
func (r *Resolver) resolveLocalReverse(ip netaddr.IP) (dnsname.FQDN, dns.RCode) {
r.mu.Lock()
ips := r.ipToHost
r.mu.Unlock()
@@ -362,7 +344,7 @@ type response struct {
Header dns.Header
Question dns.Question
// Name is the response to a PTR query.
Name string
Name dnsname.FQDN
// IP is the response to an A, AAAA, or ALL query.
IP netaddr.IP
}
@@ -425,7 +407,7 @@ func marshalAAAARecord(name dns.Name, ip netaddr.IP, builder *dns.Builder) error
// marshalPTRRecord serializes a PTR record into an active builder.
// The caller may continue using the builder following the call.
func marshalPTRRecord(queryName dns.Name, name string, builder *dns.Builder) error {
func marshalPTRRecord(queryName dns.Name, name dnsname.FQDN, builder *dns.Builder) error {
var answer dns.PTRResource
var err error
@@ -435,7 +417,7 @@ func marshalPTRRecord(queryName dns.Name, name string, builder *dns.Builder) err
Class: dns.ClassINET,
TTL: uint32(defaultTTL / time.Second),
}
answer.PTR, err = dns.NewName(name)
answer.PTR, err = dns.NewName(name.WithTrailingDot())
if err != nil {
return err
}
@@ -508,12 +490,13 @@ const (
// r._dns-sd._udp.<domain>.
// dr._dns-sd._udp.<domain>.
// lb._dns-sd._udp.<domain>.
func hasRDNSBonjourPrefix(s string) bool {
func hasRDNSBonjourPrefix(name dnsname.FQDN) bool {
// Even the shortest name containing a Bonjour prefix is long,
// so check length (cheap) and bail early if possible.
if len(s) < len("*._dns-sd._udp.0.0.0.0.in-addr.arpa.") {
if len(name) < len("*._dns-sd._udp.0.0.0.0.in-addr.arpa.") {
return false
}
s := name.WithTrailingDot()
dot := strings.IndexByte(s, '.')
if dot == -1 {
return false // shouldn't happen
@@ -548,9 +531,9 @@ func rawNameToLower(name []byte) string {
// 4.3.2.1.in-addr.arpa
// is transformed to
// 1.2.3.4
func rdnsNameToIPv4(name string) (ip netaddr.IP, ok bool) {
name = strings.TrimSuffix(name, rdnsv4Suffix)
ip, err := netaddr.ParseIP(string(name))
func rdnsNameToIPv4(name dnsname.FQDN) (ip netaddr.IP, ok bool) {
s := strings.TrimSuffix(name.WithTrailingDot(), rdnsv4Suffix)
ip, err := netaddr.ParseIP(s)
if err != nil {
return netaddr.IP{}, false
}
@@ -567,21 +550,21 @@ func rdnsNameToIPv4(name string) (ip netaddr.IP, ok bool) {
// b.a.9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.
// is transformed to
// 2001:db8::567:89ab
func rdnsNameToIPv6(name string) (ip netaddr.IP, ok bool) {
func rdnsNameToIPv6(name dnsname.FQDN) (ip netaddr.IP, ok bool) {
var b [32]byte
var ipb [16]byte
name = strings.TrimSuffix(name, rdnsv6Suffix)
s := strings.TrimSuffix(name.WithTrailingDot(), rdnsv6Suffix)
// 32 nibbles and 31 dots between them.
if len(name) != 63 {
if len(s) != 63 {
return netaddr.IP{}, false
}
// Dots and hex digits alternate.
prevDot := true
// i ranges over name backward; j ranges over b forward.
for i, j := len(name)-1, 0; i >= 0; i-- {
thisDot := (name[i] == '.')
for i, j := len(s)-1, 0; i >= 0; i-- {
thisDot := (s[i] == '.')
if prevDot == thisDot {
return netaddr.IP{}, false
}
@@ -590,7 +573,7 @@ func rdnsNameToIPv6(name string) (ip netaddr.IP, ok bool) {
if !thisDot {
// This is safe assuming alternation.
// We do not check that non-dots are hex digits: hex.Decode below will do that.
b[j] = name[i]
b[j] = s[i]
j++
}
}
@@ -605,7 +588,7 @@ func rdnsNameToIPv6(name string) (ip netaddr.IP, ok bool) {
// respondReverse returns a DNS response to a PTR query.
// It is assumed that resp.Question is populated by respond before this is called.
func (r *Resolver) respondReverse(query []byte, name string, resp *response) ([]byte, error) {
func (r *Resolver) respondReverse(query []byte, name dnsname.FQDN, resp *response) ([]byte, error) {
if hasRDNSBonjourPrefix(name) {
return nil, errNotOurName
}
@@ -613,9 +596,9 @@ func (r *Resolver) respondReverse(query []byte, name string, resp *response) ([]
var ip netaddr.IP
var ok bool
switch {
case strings.HasSuffix(name, rdnsv4Suffix):
case strings.HasSuffix(name.WithTrailingDot(), rdnsv4Suffix):
ip, ok = rdnsNameToIPv4(name)
case strings.HasSuffix(name, rdnsv6Suffix):
case strings.HasSuffix(name.WithTrailingDot(), rdnsv6Suffix):
ip, ok = rdnsNameToIPv6(name)
default:
return nil, errNotOurName
@@ -656,7 +639,12 @@ func (r *Resolver) respond(query []byte) ([]byte, error) {
return marshalResponse(resp)
}
rawName := resp.Question.Name.Data[:resp.Question.Name.Length]
name := rawNameToLower(rawName)
name, err := dnsname.ToFQDN(rawNameToLower(rawName))
if err != nil {
// DNS packet unexpectedly contains an invalid FQDN.
resp.Header.RCode = dns.RCodeFormatError
return marshalResponse(resp)
}
// Always try to handle reverse lookups; delegate inside when not found.
// This way, queries for existent nodes do not leak,