mirror of
https://github.com/tailscale/tailscale.git
synced 2025-04-20 13:41:41 +00:00
cmd/natc: fix handling of upstream and downstream nxdomain
Ensure that the upstream is always queried, so that if upstream is going to NXDOMAIN natc will also return NXDOMAIN rather than returning address allocations. At this time both IPv4 and IPv6 are still returned if upstream has a result, regardless of upstream support - this is ~ok as we're proxying. Rewrite the tests to be once again slightly closer to integration tests, but they're still very rough and in need of a refactor. Further refactors are probably needed implementation side too, as this removed rather than added units. Updates #15367 Signed-off-by: James Tucker <james@tailscale.com>
This commit is contained in:
parent
fb96137d79
commit
025fe72448
273
cmd/natc/natc.go
273
cmd/natc/natc.go
@ -26,14 +26,15 @@ import (
|
|||||||
"go4.org/netipx"
|
"go4.org/netipx"
|
||||||
"golang.org/x/net/dns/dnsmessage"
|
"golang.org/x/net/dns/dnsmessage"
|
||||||
"tailscale.com/client/local"
|
"tailscale.com/client/local"
|
||||||
|
"tailscale.com/client/tailscale/apitype"
|
||||||
"tailscale.com/cmd/natc/ippool"
|
"tailscale.com/cmd/natc/ippool"
|
||||||
"tailscale.com/envknob"
|
"tailscale.com/envknob"
|
||||||
"tailscale.com/hostinfo"
|
"tailscale.com/hostinfo"
|
||||||
"tailscale.com/ipn"
|
"tailscale.com/ipn"
|
||||||
"tailscale.com/net/netutil"
|
"tailscale.com/net/netutil"
|
||||||
"tailscale.com/tailcfg"
|
|
||||||
"tailscale.com/tsnet"
|
"tailscale.com/tsnet"
|
||||||
"tailscale.com/tsweb"
|
"tailscale.com/tsweb"
|
||||||
|
"tailscale.com/util/mak"
|
||||||
"tailscale.com/util/must"
|
"tailscale.com/util/must"
|
||||||
"tailscale.com/wgengine/netstack"
|
"tailscale.com/wgengine/netstack"
|
||||||
)
|
)
|
||||||
@ -148,14 +149,15 @@ func main() {
|
|||||||
v6ULA := ula(uint16(*siteID))
|
v6ULA := ula(uint16(*siteID))
|
||||||
c := &connector{
|
c := &connector{
|
||||||
ts: ts,
|
ts: ts,
|
||||||
lc: lc,
|
whois: lc,
|
||||||
v6ULA: v6ULA,
|
v6ULA: v6ULA,
|
||||||
ignoreDsts: ignoreDstTable,
|
ignoreDsts: ignoreDstTable,
|
||||||
ipPool: &ippool.IPPool{V6ULA: v6ULA, IPSet: addrPool},
|
ipPool: &ippool.IPPool{V6ULA: v6ULA, IPSet: addrPool},
|
||||||
routes: routes,
|
routes: routes,
|
||||||
dnsAddr: dnsAddr,
|
dnsAddr: dnsAddr,
|
||||||
|
resolver: net.DefaultResolver,
|
||||||
}
|
}
|
||||||
c.run(ctx)
|
c.run(ctx, lc)
|
||||||
}
|
}
|
||||||
|
|
||||||
func calculateAddresses(prefixes []netip.Prefix) (*netipx.IPSet, netip.Addr, *netipx.IPSet) {
|
func calculateAddresses(prefixes []netip.Prefix) (*netipx.IPSet, netip.Addr, *netipx.IPSet) {
|
||||||
@ -170,12 +172,20 @@ func calculateAddresses(prefixes []netip.Prefix) (*netipx.IPSet, netip.Addr, *ne
|
|||||||
return routesToAdvertise, dnsAddr, addrPool
|
return routesToAdvertise, dnsAddr, addrPool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type lookupNetIPer interface {
|
||||||
|
LookupNetIP(ctx context.Context, net, host string) ([]netip.Addr, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type whoiser interface {
|
||||||
|
WhoIs(ctx context.Context, remoteAddr string) (*apitype.WhoIsResponse, error)
|
||||||
|
}
|
||||||
|
|
||||||
type connector struct {
|
type connector struct {
|
||||||
// ts is the tsnet.Server used to host the connector.
|
// ts is the tsnet.Server used to host the connector.
|
||||||
ts *tsnet.Server
|
ts *tsnet.Server
|
||||||
// lc is the local.Client used to interact with the tsnet.Server hosting this
|
// whois is the local.Client used to interact with the tsnet.Server hosting this
|
||||||
// connector.
|
// connector.
|
||||||
lc *local.Client
|
whois whoiser
|
||||||
|
|
||||||
// dnsAddr is the IPv4 address to listen on for DNS requests. It is used to
|
// dnsAddr is the IPv4 address to listen on for DNS requests. It is used to
|
||||||
// prevent the app connector from assigning it to a domain.
|
// prevent the app connector from assigning it to a domain.
|
||||||
@ -197,7 +207,11 @@ type connector struct {
|
|||||||
// natc behavior, which would return a dummy ip address pointing at natc).
|
// natc behavior, which would return a dummy ip address pointing at natc).
|
||||||
ignoreDsts *bart.Table[bool]
|
ignoreDsts *bart.Table[bool]
|
||||||
|
|
||||||
|
// ipPool contains the per-peer IPv4 address assignments.
|
||||||
ipPool *ippool.IPPool
|
ipPool *ippool.IPPool
|
||||||
|
|
||||||
|
// resolver is used to lookup IP addresses for DNS queries.
|
||||||
|
resolver lookupNetIPer
|
||||||
}
|
}
|
||||||
|
|
||||||
// v6ULA is the ULA prefix used by the app connector to assign IPv6 addresses.
|
// v6ULA is the ULA prefix used by the app connector to assign IPv6 addresses.
|
||||||
@ -217,8 +231,8 @@ func ula(siteID uint16) netip.Prefix {
|
|||||||
//
|
//
|
||||||
// The passed in context is only used for the initial setup. The connector runs
|
// The passed in context is only used for the initial setup. The connector runs
|
||||||
// forever.
|
// forever.
|
||||||
func (c *connector) run(ctx context.Context) {
|
func (c *connector) run(ctx context.Context, lc *local.Client) {
|
||||||
if _, err := c.lc.EditPrefs(ctx, &ipn.MaskedPrefs{
|
if _, err := lc.EditPrefs(ctx, &ipn.MaskedPrefs{
|
||||||
AdvertiseRoutesSet: true,
|
AdvertiseRoutesSet: true,
|
||||||
Prefs: ipn.Prefs{
|
Prefs: ipn.Prefs{
|
||||||
AdvertiseRoutes: append(c.routes.Prefixes(), c.v6ULA),
|
AdvertiseRoutes: append(c.routes.Prefixes(), c.v6ULA),
|
||||||
@ -251,26 +265,6 @@ func (c *connector) serveDNS() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func lookupDestinationIP(domain string) ([]netip.Addr, error) {
|
|
||||||
netIPs, err := net.LookupIP(domain)
|
|
||||||
if err != nil {
|
|
||||||
var dnsError *net.DNSError
|
|
||||||
if errors.As(err, &dnsError) && dnsError.IsNotFound {
|
|
||||||
return nil, nil
|
|
||||||
} else {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
var addrs []netip.Addr
|
|
||||||
for _, ip := range netIPs {
|
|
||||||
a, ok := netip.AddrFromSlice(ip)
|
|
||||||
if ok {
|
|
||||||
addrs = append(addrs, a)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return addrs, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleDNS handles a DNS request to the app connector.
|
// handleDNS handles a DNS request to the app connector.
|
||||||
// It generates a response based on the request and the node that sent it.
|
// It generates a response based on the request and the node that sent it.
|
||||||
//
|
//
|
||||||
@ -285,7 +279,7 @@ func lookupDestinationIP(domain string) ([]netip.Addr, error) {
|
|||||||
func (c *connector) handleDNS(pc net.PacketConn, buf []byte, remoteAddr *net.UDPAddr) {
|
func (c *connector) handleDNS(pc net.PacketConn, buf []byte, remoteAddr *net.UDPAddr) {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
who, err := c.lc.WhoIs(ctx, remoteAddr.String())
|
who, err := c.whois.WhoIs(ctx, remoteAddr.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("HandleDNS(remote=%s): WhoIs failed: %v\n", remoteAddr.String(), err)
|
log.Printf("HandleDNS(remote=%s): WhoIs failed: %v\n", remoteAddr.String(), err)
|
||||||
return
|
return
|
||||||
@ -298,49 +292,122 @@ func (c *connector) handleDNS(pc net.PacketConn, buf []byte, remoteAddr *net.UDP
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// If there are destination ips that we don't want to route, we
|
var resolves map[string][]netip.Addr
|
||||||
// have to do a dns lookup here to find the destination ip.
|
var addrQCount int
|
||||||
if c.ignoreDsts != nil {
|
for _, q := range msg.Questions {
|
||||||
if len(msg.Questions) > 0 {
|
if q.Type != dnsmessage.TypeA && q.Type != dnsmessage.TypeAAAA {
|
||||||
q := msg.Questions[0]
|
continue
|
||||||
switch q.Type {
|
}
|
||||||
case dnsmessage.TypeAAAA, dnsmessage.TypeA:
|
addrQCount++
|
||||||
dstAddrs, err := lookupDestinationIP(q.Name.String())
|
if _, ok := resolves[q.Name.String()]; !ok {
|
||||||
|
addrs, err := c.resolver.LookupNetIP(ctx, "ip", q.Name.String())
|
||||||
|
var dnsErr *net.DNSError
|
||||||
|
if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("HandleDNS(remote=%s): lookup destination failed: %v\n", remoteAddr.String(), err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Note: If _any_ destination is ignored, pass through all of the resolved
|
||||||
|
// addresses as-is.
|
||||||
|
//
|
||||||
|
// This could result in some odd split-routing if there was a mix of
|
||||||
|
// ignored and non-ignored addresses, but it's currently the user
|
||||||
|
// preferred behavior.
|
||||||
|
if !c.ignoreDestination(addrs) {
|
||||||
|
addrs, err = c.ipPool.IPForDomain(who.Node.ID, q.Name.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("HandleDNS(remote=%s): lookup destination failed: %v\n", remoteAddr.String(), err)
|
log.Printf("HandleDNS(remote=%s): lookup destination failed: %v\n", remoteAddr.String(), err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if c.ignoreDestination(dstAddrs) {
|
}
|
||||||
bs, err := dnsResponse(&msg, dstAddrs)
|
mak.Set(&resolves, q.Name.String(), addrs)
|
||||||
// TODO (fran): treat as SERVFAIL
|
}
|
||||||
if err != nil {
|
}
|
||||||
log.Printf("HandleDNS(remote=%s): generate ignore response failed: %v\n", remoteAddr.String(), err)
|
|
||||||
return
|
rcode := dnsmessage.RCodeSuccess
|
||||||
}
|
if addrQCount > 0 && len(resolves) == 0 {
|
||||||
_, err = pc.WriteTo(bs, remoteAddr)
|
rcode = dnsmessage.RCodeNameError
|
||||||
if err != nil {
|
}
|
||||||
log.Printf("HandleDNS(remote=%s): write failed: %v\n", remoteAddr.String(), err)
|
|
||||||
}
|
b := dnsmessage.NewBuilder(nil,
|
||||||
|
dnsmessage.Header{
|
||||||
|
ID: msg.Header.ID,
|
||||||
|
Response: true,
|
||||||
|
Authoritative: true,
|
||||||
|
RCode: rcode,
|
||||||
|
})
|
||||||
|
b.EnableCompression()
|
||||||
|
|
||||||
|
if err := b.StartQuestions(); err != nil {
|
||||||
|
log.Printf("HandleDNS(remote=%s): dnsmessage start questions failed: %v\n", remoteAddr.String(), err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, q := range msg.Questions {
|
||||||
|
b.Question(q)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := b.StartAnswers(); err != nil {
|
||||||
|
log.Printf("HandleDNS(remote=%s): dnsmessage start answers failed: %v\n", remoteAddr.String(), err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, q := range msg.Questions {
|
||||||
|
switch q.Type {
|
||||||
|
case dnsmessage.TypeSOA:
|
||||||
|
if err := b.SOAResource(
|
||||||
|
dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
|
||||||
|
dnsmessage.SOAResource{NS: q.Name, MBox: tsMBox, Serial: 2023030600,
|
||||||
|
Refresh: 120, Retry: 120, Expire: 120, MinTTL: 60},
|
||||||
|
); err != nil {
|
||||||
|
log.Printf("HandleDNS(remote=%s): dnsmessage SOA resource failed: %v\n", remoteAddr.String(), err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case dnsmessage.TypeNS:
|
||||||
|
if err := b.NSResource(
|
||||||
|
dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
|
||||||
|
dnsmessage.NSResource{NS: tsMBox},
|
||||||
|
); err != nil {
|
||||||
|
log.Printf("HandleDNS(remote=%s): dnsmessage NS resource failed: %v\n", remoteAddr.String(), err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case dnsmessage.TypeAAAA:
|
||||||
|
for _, addr := range resolves[q.Name.String()] {
|
||||||
|
if !addr.Is6() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := b.AAAAResource(
|
||||||
|
dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
|
||||||
|
dnsmessage.AAAAResource{AAAA: addr.As16()},
|
||||||
|
); err != nil {
|
||||||
|
log.Printf("HandleDNS(remote=%s): dnsmessage AAAA resource failed: %v\n", remoteAddr.String(), err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case dnsmessage.TypeA:
|
||||||
|
for _, addr := range resolves[q.Name.String()] {
|
||||||
|
if !addr.Is4() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := b.AResource(
|
||||||
|
dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
|
||||||
|
dnsmessage.AResource{A: addr.As4()},
|
||||||
|
); err != nil {
|
||||||
|
log.Printf("HandleDNS(remote=%s): dnsmessage A resource failed: %v\n", remoteAddr.String(), err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// None of the destination IP addresses match an ignore destination prefix, do
|
|
||||||
// the natc thing.
|
|
||||||
|
|
||||||
resp, err := c.generateDNSResponse(&msg, who.Node.ID)
|
out, err := b.Finish()
|
||||||
// TODO (fran): treat as SERVFAIL
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("HandleDNS(remote=%s): connector handling failed: %v\n", remoteAddr.String(), err)
|
log.Printf("HandleDNS(remote=%s): dnsmessage finish failed: %v\n", remoteAddr.String(), err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// TODO (fran): treat as NXDOMAIN
|
_, err = pc.WriteTo(out, remoteAddr)
|
||||||
if len(resp) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// This connector handled the DNS request
|
|
||||||
_, err = pc.WriteTo(resp, remoteAddr)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("HandleDNS(remote=%s): write failed: %v\n", remoteAddr.String(), err)
|
log.Printf("HandleDNS(remote=%s): write failed: %v\n", remoteAddr.String(), err)
|
||||||
}
|
}
|
||||||
@ -352,89 +419,6 @@ func (c *connector) handleDNS(pc net.PacketConn, buf []byte, remoteAddr *net.UDP
|
|||||||
// to indicate that it is a fully qualified domain name.
|
// to indicate that it is a fully qualified domain name.
|
||||||
var tsMBox = dnsmessage.MustNewName("support.tailscale.com.")
|
var tsMBox = dnsmessage.MustNewName("support.tailscale.com.")
|
||||||
|
|
||||||
// generateDNSResponse generates a DNS response for the given request. The from
|
|
||||||
// argument is the NodeID of the node that sent the request.
|
|
||||||
func (c *connector) generateDNSResponse(req *dnsmessage.Message, from tailcfg.NodeID) ([]byte, error) {
|
|
||||||
var addrs []netip.Addr
|
|
||||||
if len(req.Questions) > 0 {
|
|
||||||
switch req.Questions[0].Type {
|
|
||||||
case dnsmessage.TypeAAAA, dnsmessage.TypeA:
|
|
||||||
var err error
|
|
||||||
addrs, err = c.ipPool.IPForDomain(from, req.Questions[0].Name.String())
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return dnsResponse(req, addrs)
|
|
||||||
}
|
|
||||||
|
|
||||||
// dnsResponse makes a DNS response for the natc. If the dnsmessage is requesting TypeAAAA
|
|
||||||
// or TypeA the provided addrs of the requested type will be used.
|
|
||||||
func dnsResponse(req *dnsmessage.Message, addrs []netip.Addr) ([]byte, error) {
|
|
||||||
b := dnsmessage.NewBuilder(nil,
|
|
||||||
dnsmessage.Header{
|
|
||||||
ID: req.Header.ID,
|
|
||||||
Response: true,
|
|
||||||
Authoritative: true,
|
|
||||||
})
|
|
||||||
b.EnableCompression()
|
|
||||||
|
|
||||||
if len(req.Questions) == 0 {
|
|
||||||
return b.Finish()
|
|
||||||
}
|
|
||||||
q := req.Questions[0]
|
|
||||||
if err := b.StartQuestions(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err := b.Question(q); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err := b.StartAnswers(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
switch q.Type {
|
|
||||||
case dnsmessage.TypeAAAA, dnsmessage.TypeA:
|
|
||||||
want6 := q.Type == dnsmessage.TypeAAAA
|
|
||||||
for _, ip := range addrs {
|
|
||||||
if want6 != ip.Is6() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if want6 {
|
|
||||||
if err := b.AAAAResource(
|
|
||||||
dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 5},
|
|
||||||
dnsmessage.AAAAResource{AAAA: ip.As16()},
|
|
||||||
); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if err := b.AResource(
|
|
||||||
dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 5},
|
|
||||||
dnsmessage.AResource{A: ip.As4()},
|
|
||||||
); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case dnsmessage.TypeSOA:
|
|
||||||
if err := b.SOAResource(
|
|
||||||
dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
|
|
||||||
dnsmessage.SOAResource{NS: q.Name, MBox: tsMBox, Serial: 2023030600,
|
|
||||||
Refresh: 120, Retry: 120, Expire: 120, MinTTL: 60},
|
|
||||||
); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
case dnsmessage.TypeNS:
|
|
||||||
if err := b.NSResource(
|
|
||||||
dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
|
|
||||||
dnsmessage.NSResource{NS: tsMBox},
|
|
||||||
); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return b.Finish()
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleTCPFlow handles a TCP flow from the given source to the given
|
// handleTCPFlow handles a TCP flow from the given source to the given
|
||||||
// destination. It uses the source address to determine the node that sent the
|
// destination. It uses the source address to determine the node that sent the
|
||||||
// request and the destination address to determine the domain that the request
|
// request and the destination address to determine the domain that the request
|
||||||
@ -443,7 +427,7 @@ func dnsResponse(req *dnsmessage.Message, addrs []netip.Addr) ([]byte, error) {
|
|||||||
func (c *connector) handleTCPFlow(src, dst netip.AddrPort) (handler func(net.Conn), intercept bool) {
|
func (c *connector) handleTCPFlow(src, dst netip.AddrPort) (handler func(net.Conn), intercept bool) {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
who, err := c.lc.WhoIs(ctx, src.Addr().String())
|
who, err := c.whois.WhoIs(ctx, src.Addr().String())
|
||||||
cancel()
|
cancel()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("HandleTCPFlow: WhoIs failed: %v\n", err)
|
log.Printf("HandleTCPFlow: WhoIs failed: %v\n", err)
|
||||||
@ -461,6 +445,9 @@ func (c *connector) handleTCPFlow(src, dst netip.AddrPort) (handler func(net.Con
|
|||||||
// ignoreDestination reports whether any of the provided dstAddrs match the prefixes configured
|
// ignoreDestination reports whether any of the provided dstAddrs match the prefixes configured
|
||||||
// in --ignore-destinations
|
// in --ignore-destinations
|
||||||
func (c *connector) ignoreDestination(dstAddrs []netip.Addr) bool {
|
func (c *connector) ignoreDestination(dstAddrs []netip.Addr) bool {
|
||||||
|
if c.ignoreDsts == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
for _, a := range dstAddrs {
|
for _, a := range dstAddrs {
|
||||||
if _, ok := c.ignoreDsts.Lookup(a); ok {
|
if _, ok := c.ignoreDsts.Lookup(a); ok {
|
||||||
return true
|
return true
|
||||||
@ -488,6 +475,8 @@ func proxyTCPConn(c net.Conn, dest string) {
|
|||||||
return netutil.NewOneConnListener(c, nil), nil
|
return netutil.NewOneConnListener(c, nil), nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
// XXX(raggi): if the connection here resolves to an ignored destination,
|
||||||
|
// the connection should be closed/failed.
|
||||||
p.AddRoute(addrPortStr, &tcpproxy.DialProxy{
|
p.AddRoute(addrPortStr, &tcpproxy.DialProxy{
|
||||||
Addr: fmt.Sprintf("%s:%s", dest, port),
|
Addr: fmt.Sprintf("%s:%s", dest, port),
|
||||||
})
|
})
|
||||||
|
@ -4,14 +4,20 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/google/go-cmp/cmp"
|
|
||||||
"golang.org/x/net/dns/dnsmessage"
|
"golang.org/x/net/dns/dnsmessage"
|
||||||
|
"tailscale.com/client/tailscale/apitype"
|
||||||
"tailscale.com/cmd/natc/ippool"
|
"tailscale.com/cmd/natc/ippool"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
|
"tailscale.com/util/must"
|
||||||
)
|
)
|
||||||
|
|
||||||
func prefixEqual(a, b netip.Prefix) bool {
|
func prefixEqual(a, b netip.Prefix) bool {
|
||||||
@ -41,22 +47,86 @@ func TestULA(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type recordingPacketConn struct {
|
||||||
|
writes [][]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *recordingPacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
|
||||||
|
w.writes = append(w.writes, b)
|
||||||
|
return len(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *recordingPacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
|
||||||
|
return 0, nil, io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *recordingPacketConn) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *recordingPacketConn) LocalAddr() net.Addr {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *recordingPacketConn) RemoteAddr() net.Addr {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *recordingPacketConn) SetDeadline(t time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *recordingPacketConn) SetReadDeadline(t time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *recordingPacketConn) SetWriteDeadline(t time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type resolver struct {
|
||||||
|
resolves map[string][]netip.Addr
|
||||||
|
fails map[string]bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *resolver) LookupNetIP(ctx context.Context, _net, host string) ([]netip.Addr, error) {
|
||||||
|
if addrs, ok := r.resolves[host]; ok {
|
||||||
|
return addrs, nil
|
||||||
|
}
|
||||||
|
if _, ok := r.fails[host]; ok {
|
||||||
|
return nil, &net.DNSError{IsTimeout: false, IsNotFound: false, Name: host, IsTemporary: true}
|
||||||
|
}
|
||||||
|
return nil, &net.DNSError{IsNotFound: true, Name: host}
|
||||||
|
}
|
||||||
|
|
||||||
|
type whois struct {
|
||||||
|
peers map[string]*apitype.WhoIsResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *whois) WhoIs(ctx context.Context, remoteAddr string) (*apitype.WhoIsResponse, error) {
|
||||||
|
addr := netip.MustParseAddrPort(remoteAddr).Addr().String()
|
||||||
|
if peer, ok := w.peers[addr]; ok {
|
||||||
|
return peer, nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("peer not found")
|
||||||
|
}
|
||||||
|
|
||||||
func TestDNSResponse(t *testing.T) {
|
func TestDNSResponse(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
questions []dnsmessage.Question
|
questions []dnsmessage.Question
|
||||||
addrs []netip.Addr
|
|
||||||
wantEmpty bool
|
wantEmpty bool
|
||||||
wantAnswers []struct {
|
wantAnswers []struct {
|
||||||
name string
|
name string
|
||||||
qType dnsmessage.Type
|
qType dnsmessage.Type
|
||||||
addr netip.Addr
|
addr netip.Addr
|
||||||
}
|
}
|
||||||
|
wantNXDOMAIN bool
|
||||||
|
wantIgnored bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "empty_request",
|
name: "empty_request",
|
||||||
questions: []dnsmessage.Question{},
|
questions: []dnsmessage.Question{},
|
||||||
addrs: []netip.Addr{},
|
|
||||||
wantEmpty: false,
|
wantEmpty: false,
|
||||||
wantAnswers: nil,
|
wantAnswers: nil,
|
||||||
},
|
},
|
||||||
@ -69,7 +139,6 @@ func TestDNSResponse(t *testing.T) {
|
|||||||
Class: dnsmessage.ClassINET,
|
Class: dnsmessage.ClassINET,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
addrs: []netip.Addr{netip.MustParseAddr("100.64.1.5")},
|
|
||||||
wantAnswers: []struct {
|
wantAnswers: []struct {
|
||||||
name string
|
name string
|
||||||
qType dnsmessage.Type
|
qType dnsmessage.Type
|
||||||
@ -78,7 +147,7 @@ func TestDNSResponse(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "example.com.",
|
name: "example.com.",
|
||||||
qType: dnsmessage.TypeA,
|
qType: dnsmessage.TypeA,
|
||||||
addr: netip.MustParseAddr("100.64.1.5"),
|
addr: netip.MustParseAddr("100.64.0.0"),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -91,7 +160,6 @@ func TestDNSResponse(t *testing.T) {
|
|||||||
Class: dnsmessage.ClassINET,
|
Class: dnsmessage.ClassINET,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
addrs: []netip.Addr{netip.MustParseAddr("fd7a:115c:a1e0:a99c:0001:0505:0505:0505")},
|
|
||||||
wantAnswers: []struct {
|
wantAnswers: []struct {
|
||||||
name string
|
name string
|
||||||
qType dnsmessage.Type
|
qType dnsmessage.Type
|
||||||
@ -100,7 +168,7 @@ func TestDNSResponse(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "example.com.",
|
name: "example.com.",
|
||||||
qType: dnsmessage.TypeAAAA,
|
qType: dnsmessage.TypeAAAA,
|
||||||
addr: netip.MustParseAddr("fd7a:115c:a1e0:a99c:0001:0505:0505:0505"),
|
addr: netip.MustParseAddr("fd7a:115c:a1e0::"),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -113,7 +181,6 @@ func TestDNSResponse(t *testing.T) {
|
|||||||
Class: dnsmessage.ClassINET,
|
Class: dnsmessage.ClassINET,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
addrs: []netip.Addr{},
|
|
||||||
wantAnswers: nil,
|
wantAnswers: nil,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -125,89 +192,210 @@ func TestDNSResponse(t *testing.T) {
|
|||||||
Class: dnsmessage.ClassINET,
|
Class: dnsmessage.ClassINET,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
addrs: []netip.Addr{},
|
|
||||||
wantAnswers: nil,
|
wantAnswers: nil,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "nxdomain",
|
||||||
|
questions: []dnsmessage.Question{
|
||||||
|
{
|
||||||
|
Name: dnsmessage.MustNewName("noexist.example.com."),
|
||||||
|
Type: dnsmessage.TypeA,
|
||||||
|
Class: dnsmessage.ClassINET,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantNXDOMAIN: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "servfail",
|
||||||
|
questions: []dnsmessage.Question{
|
||||||
|
{
|
||||||
|
Name: dnsmessage.MustNewName("fail.example.com."),
|
||||||
|
Type: dnsmessage.TypeA,
|
||||||
|
Class: dnsmessage.ClassINET,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantEmpty: true, // TODO: pass through instead?
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ignored",
|
||||||
|
questions: []dnsmessage.Question{
|
||||||
|
{
|
||||||
|
Name: dnsmessage.MustNewName("ignore.example.com."),
|
||||||
|
Type: dnsmessage.TypeA,
|
||||||
|
Class: dnsmessage.ClassINET,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantAnswers: []struct {
|
||||||
|
name string
|
||||||
|
qType dnsmessage.Type
|
||||||
|
addr netip.Addr
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "ignore.example.com.",
|
||||||
|
qType: dnsmessage.TypeA,
|
||||||
|
addr: netip.MustParseAddr("8.8.4.4"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantIgnored: true,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var rpc recordingPacketConn
|
||||||
|
remoteAddr := must.Get(net.ResolveUDPAddr("udp", "100.64.254.1:12345"))
|
||||||
|
|
||||||
|
routes, dnsAddr, addrPool := calculateAddresses([]netip.Prefix{netip.MustParsePrefix("10.64.0.0/24")})
|
||||||
|
v6ULA := ula(1)
|
||||||
|
c := connector{
|
||||||
|
resolver: &resolver{
|
||||||
|
resolves: map[string][]netip.Addr{
|
||||||
|
"example.com.": {
|
||||||
|
netip.MustParseAddr("8.8.8.8"),
|
||||||
|
netip.MustParseAddr("2001:4860:4860::8888"),
|
||||||
|
},
|
||||||
|
"ignore.example.com.": {
|
||||||
|
netip.MustParseAddr("8.8.4.4"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
fails: map[string]bool{
|
||||||
|
"fail.example.com.": true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
whois: &whois{
|
||||||
|
peers: map[string]*apitype.WhoIsResponse{
|
||||||
|
"100.64.254.1": {
|
||||||
|
Node: &tailcfg.Node{ID: 123},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ignoreDsts: &bart.Table[bool]{},
|
||||||
|
routes: routes,
|
||||||
|
v6ULA: v6ULA,
|
||||||
|
ipPool: &ippool.IPPool{V6ULA: v6ULA, IPSet: addrPool},
|
||||||
|
dnsAddr: dnsAddr,
|
||||||
|
}
|
||||||
|
c.ignoreDsts.Insert(netip.MustParsePrefix("8.8.4.4/32"), true)
|
||||||
|
|
||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
req := &dnsmessage.Message{
|
rb := dnsmessage.NewBuilder(nil,
|
||||||
Header: dnsmessage.Header{
|
dnsmessage.Header{
|
||||||
ID: 1234,
|
ID: 1234,
|
||||||
},
|
},
|
||||||
Questions: tc.questions,
|
)
|
||||||
|
must.Do(rb.StartQuestions())
|
||||||
|
for _, q := range tc.questions {
|
||||||
|
rb.Question(q)
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := dnsResponse(req, tc.addrs)
|
c.handleDNS(&rpc, must.Get(rb.Finish()), remoteAddr)
|
||||||
|
|
||||||
|
writes := rpc.writes
|
||||||
|
rpc.writes = rpc.writes[:0]
|
||||||
|
|
||||||
|
if tc.wantEmpty {
|
||||||
|
if len(writes) != 0 {
|
||||||
|
t.Errorf("handleDNS() returned non-empty response when expected empty")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !tc.wantEmpty && len(writes) != 1 {
|
||||||
|
t.Fatalf("handleDNS() returned an unexpected number of responses: %d, want 1", len(writes))
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := writes[0]
|
||||||
|
var msg dnsmessage.Message
|
||||||
|
err := msg.Unpack(resp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("dnsResponse() error = %v", err)
|
t.Fatalf("Failed to unpack response: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if tc.wantEmpty && len(resp) != 0 {
|
if !msg.Header.Response {
|
||||||
t.Errorf("dnsResponse() returned non-empty response when expected empty")
|
t.Errorf("Response header is not set")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !tc.wantEmpty && len(resp) == 0 {
|
if msg.Header.ID != 1234 {
|
||||||
t.Errorf("dnsResponse() returned empty response when expected non-empty")
|
t.Errorf("Response ID = %d, want %d", msg.Header.ID, 1234)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(resp) > 0 {
|
if len(tc.wantAnswers) > 0 {
|
||||||
var msg dnsmessage.Message
|
if len(msg.Answers) != len(tc.wantAnswers) {
|
||||||
err = msg.Unpack(resp)
|
t.Errorf("got %d answers, want %d:\n%s", len(msg.Answers), len(tc.wantAnswers), msg.GoString())
|
||||||
if err != nil {
|
} else {
|
||||||
t.Fatalf("Failed to unpack response: %v", err)
|
for i, want := range tc.wantAnswers {
|
||||||
}
|
ans := msg.Answers[i]
|
||||||
|
|
||||||
if !msg.Header.Response {
|
gotName := ans.Header.Name.String()
|
||||||
t.Errorf("Response header is not set")
|
if gotName != want.name {
|
||||||
}
|
t.Errorf("answer[%d] name = %s, want %s", i, gotName, want.name)
|
||||||
|
}
|
||||||
|
|
||||||
if msg.Header.ID != req.Header.ID {
|
if ans.Header.Type != want.qType {
|
||||||
t.Errorf("Response ID = %d, want %d", msg.Header.ID, req.Header.ID)
|
t.Errorf("answer[%d] type = %v, want %v", i, ans.Header.Type, want.qType)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(tc.wantAnswers) > 0 {
|
switch want.qType {
|
||||||
if len(msg.Answers) != len(tc.wantAnswers) {
|
case dnsmessage.TypeA:
|
||||||
t.Errorf("got %d answers, want %d", len(msg.Answers), len(tc.wantAnswers))
|
if ans.Body.(*dnsmessage.AResource) == nil {
|
||||||
} else {
|
t.Errorf("answer[%d] not an A record", i)
|
||||||
for i, want := range tc.wantAnswers {
|
continue
|
||||||
ans := msg.Answers[i]
|
|
||||||
|
|
||||||
gotName := ans.Header.Name.String()
|
|
||||||
if gotName != want.name {
|
|
||||||
t.Errorf("answer[%d] name = %s, want %s", i, gotName, want.name)
|
|
||||||
}
|
}
|
||||||
|
resource := ans.Body.(*dnsmessage.AResource)
|
||||||
|
gotIP := netip.AddrFrom4([4]byte(resource.A))
|
||||||
|
|
||||||
if ans.Header.Type != want.qType {
|
var ips []netip.Addr
|
||||||
t.Errorf("answer[%d] type = %v, want %v", i, ans.Header.Type, want.qType)
|
if tc.wantIgnored {
|
||||||
|
ips = must.Get(c.resolver.LookupNetIP(t.Context(), "ip4", want.name))
|
||||||
|
} else {
|
||||||
|
ips = must.Get(c.ipPool.IPForDomain(tailcfg.NodeID(123), want.name))
|
||||||
}
|
}
|
||||||
|
var wantIP netip.Addr
|
||||||
var gotIP netip.Addr
|
for _, ip := range ips {
|
||||||
switch want.qType {
|
if ip.Is4() {
|
||||||
case dnsmessage.TypeA:
|
wantIP = ip
|
||||||
if ans.Body.(*dnsmessage.AResource) == nil {
|
break
|
||||||
t.Errorf("answer[%d] not an A record", i)
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
resource := ans.Body.(*dnsmessage.AResource)
|
|
||||||
gotIP = netip.AddrFrom4([4]byte(resource.A))
|
|
||||||
case dnsmessage.TypeAAAA:
|
|
||||||
if ans.Body.(*dnsmessage.AAAAResource) == nil {
|
|
||||||
t.Errorf("answer[%d] not an AAAA record", i)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
resource := ans.Body.(*dnsmessage.AAAAResource)
|
|
||||||
gotIP = netip.AddrFrom16([16]byte(resource.AAAA))
|
|
||||||
}
|
}
|
||||||
|
if gotIP != wantIP {
|
||||||
|
t.Errorf("answer[%d] IP = %s, want %s", i, gotIP, wantIP)
|
||||||
|
}
|
||||||
|
case dnsmessage.TypeAAAA:
|
||||||
|
if ans.Body.(*dnsmessage.AAAAResource) == nil {
|
||||||
|
t.Errorf("answer[%d] not an AAAA record", i)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
resource := ans.Body.(*dnsmessage.AAAAResource)
|
||||||
|
gotIP := netip.AddrFrom16([16]byte(resource.AAAA))
|
||||||
|
|
||||||
if gotIP != want.addr {
|
var ips []netip.Addr
|
||||||
t.Errorf("answer[%d] IP = %s, want %s", i, gotIP, want.addr)
|
if tc.wantIgnored {
|
||||||
|
ips = must.Get(c.resolver.LookupNetIP(t.Context(), "ip6", want.name))
|
||||||
|
} else {
|
||||||
|
ips = must.Get(c.ipPool.IPForDomain(tailcfg.NodeID(123), want.name))
|
||||||
|
}
|
||||||
|
var wantIP netip.Addr
|
||||||
|
for _, ip := range ips {
|
||||||
|
if ip.Is6() {
|
||||||
|
wantIP = ip
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if gotIP != wantIP {
|
||||||
|
t.Errorf("answer[%d] IP = %s, want %s", i, gotIP, wantIP)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if tc.wantNXDOMAIN {
|
||||||
|
if msg.RCode != dnsmessage.RCodeNameError {
|
||||||
|
t.Errorf("expected NXDOMAIN, got %v", msg.RCode)
|
||||||
|
}
|
||||||
|
if len(msg.Answers) != 0 {
|
||||||
|
t.Errorf("expected no answers, got %d", len(msg.Answers))
|
||||||
|
}
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -257,53 +445,3 @@ func TestIgnoreDestination(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConnectorGenerateDNSResponse(t *testing.T) {
|
|
||||||
v6ULA := netip.MustParsePrefix("fd7a:115c:a1e0:a99c:0001::/80")
|
|
||||||
routes, dnsAddr, addrPool := calculateAddresses([]netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")})
|
|
||||||
c := &connector{
|
|
||||||
v6ULA: v6ULA,
|
|
||||||
ipPool: &ippool.IPPool{V6ULA: v6ULA, IPSet: addrPool},
|
|
||||||
routes: routes,
|
|
||||||
dnsAddr: dnsAddr,
|
|
||||||
}
|
|
||||||
|
|
||||||
req := &dnsmessage.Message{
|
|
||||||
Header: dnsmessage.Header{ID: 1234},
|
|
||||||
Questions: []dnsmessage.Question{
|
|
||||||
{
|
|
||||||
Name: dnsmessage.MustNewName("example.com."),
|
|
||||||
Type: dnsmessage.TypeA,
|
|
||||||
Class: dnsmessage.ClassINET,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
nodeID := tailcfg.NodeID(12345)
|
|
||||||
|
|
||||||
resp1, err := c.generateDNSResponse(req, nodeID)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("generateDNSResponse() error = %v", err)
|
|
||||||
}
|
|
||||||
if len(resp1) == 0 {
|
|
||||||
t.Fatalf("generateDNSResponse() returned empty response")
|
|
||||||
}
|
|
||||||
|
|
||||||
resp2, err := c.generateDNSResponse(req, nodeID)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("generateDNSResponse() second call error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !cmp.Equal(resp1, resp2) {
|
|
||||||
t.Errorf("generateDNSResponse() responses differ between calls")
|
|
||||||
}
|
|
||||||
|
|
||||||
var msg dnsmessage.Message
|
|
||||||
err = msg.Unpack(resp1)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("dnsmessage Unpack error = %v", err)
|
|
||||||
}
|
|
||||||
if len(msg.Answers) != 1 {
|
|
||||||
t.Fatalf("expected 1 answer, got: %d", len(msg.Answers))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user