diff --git a/cmd/tailscaled/tailscaled.go b/cmd/tailscaled/tailscaled.go index 237cdfb55..065dd6aeb 100644 --- a/cmd/tailscaled/tailscaled.go +++ b/cmd/tailscaled/tailscaled.go @@ -577,6 +577,8 @@ func getLocalBackend(ctx context.Context, logf logger.Logf, logID logid.PublicID } } if socksListener != nil || httpProxyListener != nil { + dialer.UserDialCustomResolverDial = dns.Quad100ResolverDial(ctx, sys.DNSManager.Get()) + var addrs []string if httpProxyListener != nil { hs := &http.Server{Handler: httpProxyHandler(dialer.UserDial)} diff --git a/net/dns/quad100.go b/net/dns/quad100.go new file mode 100644 index 000000000..39071c2ad --- /dev/null +++ b/net/dns/quad100.go @@ -0,0 +1,67 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package dns + +import ( + "bytes" + "context" + "net" + "net/netip" + "time" +) + +type ManagerConn struct { + Ctx context.Context + DnsManager *Manager + + rbuf bytes.Buffer +} + +var ( + _ net.Conn = (*ManagerConn)(nil) + _ net.PacketConn = (*ManagerConn)(nil) // be a PacketConn to change net.Resolver semantics +) + +func (*ManagerConn) Close() error { return nil } +func (*ManagerConn) LocalAddr() net.Addr { return todoAddr{} } +func (*ManagerConn) RemoteAddr() net.Addr { return todoAddr{} } +func (*ManagerConn) SetDeadline(t time.Time) error { return nil } +func (*ManagerConn) SetReadDeadline(t time.Time) error { return nil } +func (*ManagerConn) SetWriteDeadline(t time.Time) error { return nil } + +func (c *ManagerConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + return c.Write(p) +} + +func (c *ManagerConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + n, err = c.Read(p) + return n, todoAddr{}, err +} + +func (c *ManagerConn) Read(p []byte) (n int, err error) { + return c.rbuf.Read(p) +} + +func (c *ManagerConn) Write(packet []byte) (n int, err error) { + pkt, err := c.DnsManager.Query(c.Ctx, packet, "tcp", netip.AddrPort{}) + if err != nil { + return 0, err + } + c.rbuf.Write(pkt) + return len(packet), nil +} + +type todoAddr struct{} + +func (todoAddr) Network() string { return "unused" } +func (todoAddr) String() string { return "unused-todoAddr" } + +func Quad100ResolverDial(ctx context.Context, mgr *Manager) func(ctx context.Context, network, address string) (net.Conn, error) { + return func(ctx context.Context, network, address string) (net.Conn, error) { + return &ManagerConn{ + Ctx: ctx, + DnsManager: mgr, + }, nil + } +} diff --git a/net/dns/quad100_test.go b/net/dns/quad100_test.go new file mode 100644 index 000000000..663b9a1bc --- /dev/null +++ b/net/dns/quad100_test.go @@ -0,0 +1,115 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package dns + +import ( + "context" + "net" + "testing" + + dns "golang.org/x/net/dns/dnsmessage" + "tailscale.com/health" + "tailscale.com/net/netmon" + "tailscale.com/net/tsdial" + "tailscale.com/util/dnsname" +) + +func TestQuad100Conn(t *testing.T) { + f := fakeOSConfigurator{ + SplitDNS: true, + BaseConfig: OSConfig{ + Nameservers: mustIPs("8.8.8.8"), + SearchDomains: fqdns("coffee.shop"), + }, + } + m := NewManager(t.Logf, &f, new(health.Tracker), tsdial.NewDialer(netmon.NewStatic()), nil, nil, "") + m.resolver.TestOnlySetHook(f.SetResolver) + m.Set(Config{ + Hosts: hosts( + "dave.ts.net.", "1.2.3.4", + "matt.ts.net.", "2.3.4.5"), + Routes: upstreams("ts.net", ""), + SearchDomains: fqdns("tailscale.com", "universe.tf"), + }) + defer m.Down() + + q100 := &ManagerConn{ + Ctx: context.Background(), + DnsManager: m, + } + defer q100.Close() + + var b []byte + domain := dnsname.FQDN("matt.ts.net.") + + // Send a query + b = mkDNSRequest(domain, dns.TypeA, addEDNS) + _, err := q100.Write(b) + if err != nil { + t.Fatal(err) + } + + resp := make([]byte, 100) + if _, err := q100.Read(resp); err != nil { + t.Fatalf("reading data: %v", err) + } + + var parser dns.Parser + if _, err := parser.Start(resp); err != nil { + t.Errorf("parser.Start() failed: %v", err) + } + _, err = parser.Question() + if err != nil { + t.Errorf("parser.Question(): %v", err) + } + if err := parser.SkipAllQuestions(); err != nil { + t.Errorf("parser.SkipAllQuestions(): %v", err) + } + ah, err := parser.AnswerHeader() + if err != nil { + t.Errorf("parser.AnswerHeader(): %v", err) + } + if ah.Type != dns.TypeA { + t.Errorf("unexpected answer type: got %v, want %v", ah.Type, dns.TypeA) + } + res, err := parser.AResource() + if err != nil { + t.Errorf("parser.AResource(): %v", err) + } + if net.IP(res.A[:]).String() != "2.3.4.5" { + t.Fatalf("dns query did not return expected result") + } +} + +func TestQuad100ResolverDial(t *testing.T) { + f := fakeOSConfigurator{ + SplitDNS: true, + BaseConfig: OSConfig{ + Nameservers: mustIPs("8.8.8.8"), + SearchDomains: fqdns("coffee.shop"), + }, + } + m := NewManager(t.Logf, &f, new(health.Tracker), tsdial.NewDialer(netmon.NewStatic()), nil, nil, "") + m.resolver.TestOnlySetHook(f.SetResolver) + m.Set(Config{ + Hosts: hosts( + "dave.ts.net.", "1.2.3.4", + "matt.ts.net.", "2.3.4.5"), + Routes: upstreams("ts.net", ""), + SearchDomains: fqdns("tailscale.com", "universe.tf"), + }) + defer m.Down() + + var r net.Resolver + r.Dial = Quad100ResolverDial(context.Background(), m) + + ips, err := r.LookupHost(context.Background(), "matt.ts.net") + if err != nil { + t.Errorf("could not resolve host: %v", err) + } + + if ips[0] != "2.3.4.5" { + t.Fatalf("dns query did not return expected result") + } +} diff --git a/net/tsdial/tsdial.go b/net/tsdial/tsdial.go index 3606dd67f..f225d22ab 100644 --- a/net/tsdial/tsdial.go +++ b/net/tsdial/tsdial.go @@ -63,6 +63,10 @@ type Dialer struct { // If nil, it's not used. NetstackDialUDP func(context.Context, netip.AddrPort) (net.Conn, error) + // UserDialCustomResolverDial if non-nil is invoked by UserDial to resolve a destination address. + // It is invoked after the in-memory tailnet machine map. + UserDialCustomResolverDial func(context.Context, string, string) (net.Conn, error) + peerClientOnce sync.Once peerClient *http.Client @@ -308,16 +312,17 @@ func (d *Dialer) userDialResolve(ctx context.Context, network, addr string) (net return ipp, err } - // Otherwise, hit the network. - // TODO(bradfitz): wire up net/dnscache too. + // Try tsdns resolver next to resolve SplitDNS host, port, err := splitHostPort(addr) if err != nil { // addr is malformed. return netip.AddrPort{}, err } + // Otherwise, hit the network. + var r net.Resolver if exitDNSDoH != "" && runtime.GOOS != "windows" { // Windows: https://github.com/golang/go/issues/33097 r.PreferGo = true @@ -329,6 +334,9 @@ func (d *Dialer) userDialResolve(ctx context.Context, network, addr string) (net dnsCache: d.dnsCache, }, nil } + } else if d.UserDialCustomResolverDial != nil { + r.PreferGo = true + r.Dial = d.UserDialCustomResolverDial } ips, err := r.LookupIP(ctx, ipNetOfNetwork(network), host)