From 72c8f7700b5e477fbd197c97105a364cbd3380a3 Mon Sep 17 00:00:00 2001 From: Andrew Dunham Date: Thu, 13 Jun 2024 10:48:45 -0700 Subject: [PATCH] wgengine/netstack: add test for #12448 This refactors the logic for determining whether a packet should be sent to the host or not into a function, and then adds tests for it. Updates #11304 Updates #12448 Signed-off-by: Andrew Dunham Change-Id: Ief9afa98eaffae00e21ceb7db073c61b170355e5 --- wgengine/netstack/netstack.go | 93 +++++++----- wgengine/netstack/netstack_test.go | 231 ++++++++++++++++++++++++++++- 2 files changed, 275 insertions(+), 49 deletions(-) diff --git a/wgengine/netstack/netstack.go b/wgengine/netstack/netstack.go index a385b9593..fde9ab651 100644 --- a/wgengine/netstack/netstack.go +++ b/wgengine/netstack/netstack.go @@ -808,47 +808,7 @@ func (ns *Impl) inject() { // However, some uses of netstack (presently, magic DNS) // send traffic destined for the local device, hence must // be injected 'inbound'. - sendToHost := false - - // Determine if the packet is from a service IP, in which case it - // needs to go back into the machines network (inbound) instead of - // out. - // TODO(tom): Figure out if its safe to modify packet.Parsed to fill in - // the IP src/dest even if its missing the rest of the pkt. - // That way we dont have to do this twitchy-af byte-yeeting. - hdr := pkt.Network() - switch v := hdr.(type) { - case header.IPv4: - srcIP := netip.AddrFrom4(v.SourceAddress().As4()) - if serviceIP == srcIP { - sendToHost = true - } - case header.IPv6: - srcIP := netip.AddrFrom16(v.SourceAddress().As16()) - if srcIP == serviceIPv6 { - sendToHost = true - } else if viaRange.Contains(srcIP) { - // Only send to the host if this 4via6 route is - // something this node handles. - if ns.lb != nil && ns.lb.ShouldHandleViaIP(srcIP) { - dstIP := netip.AddrFrom16(v.DestinationAddress().As16()) - // Also, only forward to the host if - // the packet is destined for a local - // IP; otherwise, we'd send traffic - // that's intended for another peer - // from the local 4via6 address to the - // host instead of outbound to - // WireGuard. See: - // https://github.com/tailscale/tailscale/issues/12448 - sendToHost = ns.isLocalIP(dstIP) - if debugNetstack() { - ns.logf("netstack: sending 4via6 packet to host: src=%v dst=%v", srcIP, dstIP) - } - } - } - default: - // unknown; don't forward to host - } + sendToHost := ns.shouldSendToHost(pkt) // pkt has a non-zero refcount, so injection methods takes // ownership of one count and will decrement on completion. @@ -866,6 +826,57 @@ func (ns *Impl) inject() { } } +// shouldSendToHost determines if the provided packet should be sent to the +// host (i.e the current machine running Tailscale), in which case it will +// return true. It will return false if the packet should be sent outbound, for +// transit via WireGuard to another Tailscale node. +func (ns *Impl) shouldSendToHost(pkt *stack.PacketBuffer) bool { + // Determine if the packet is from a service IP (100.100.100.100 or the + // IPv6 variant), in which case it needs to go back into the machine's + // network (inbound) instead of out. + hdr := pkt.Network() + switch v := hdr.(type) { + case header.IPv4: + srcIP := netip.AddrFrom4(v.SourceAddress().As4()) + if serviceIP == srcIP { + return true + } + + case header.IPv6: + srcIP := netip.AddrFrom16(v.SourceAddress().As16()) + if srcIP == serviceIPv6 { + return true + } + + if viaRange.Contains(srcIP) { + // Only send to the host if this 4via6 route is + // something this node handles. + if ns.lb != nil && ns.lb.ShouldHandleViaIP(srcIP) { + dstIP := netip.AddrFrom16(v.DestinationAddress().As16()) + // Also, only forward to the host if the packet + // is destined for a local IP; otherwise, we'd + // send traffic that's intended for another + // peer from the local 4via6 address to the + // host instead of outbound to WireGuard. See: + // https://github.com/tailscale/tailscale/issues/12448 + if ns.isLocalIP(dstIP) { + return true + } + if debugNetstack() { + ns.logf("netstack: sending 4via6 packet to host: src=%v dst=%v", srcIP, dstIP) + } + } + } + default: + // unknown; don't forward to host + if debugNetstack() { + ns.logf("netstack: unexpected packet in shouldSendToHost: %T", v) + } + } + + return false +} + // isLocalIP reports whether ip is a Tailscale IP assigned to this // node directly (but not a subnet-routed IP). func (ns *Impl) isLocalIP(ip netip.Addr) bool { diff --git a/wgengine/netstack/netstack_test.go b/wgengine/netstack/netstack_test.go index 7a3affda7..43287d876 100644 --- a/wgengine/netstack/netstack_test.go +++ b/wgengine/netstack/netstack_test.go @@ -13,8 +13,10 @@ "testing" "time" + "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" "tailscale.com/envknob" "tailscale.com/ipn" "tailscale.com/ipn/ipnlocal" @@ -94,12 +96,12 @@ func getMemStats() (ms runtime.MemStats) { return } -func makeNetstack(t *testing.T, config func(*Impl)) *Impl { +func makeNetstack(tb testing.TB, config func(*Impl)) *Impl { tunDev := tstun.NewFake() sys := &tsd.System{} sys.Set(new(mem.Store)) dialer := new(tsdial.Dialer) - logf := tstest.WhileTestRunningLogger(t) + logf := tstest.WhileTestRunningLogger(tb) eng, err := wgengine.NewUserspaceEngine(logf, wgengine.Config{ Tun: tunDev, Dialer: dialer, @@ -107,20 +109,20 @@ func makeNetstack(t *testing.T, config func(*Impl)) *Impl { HealthTracker: sys.HealthTracker(), }) if err != nil { - t.Fatal(err) + tb.Fatal(err) } - t.Cleanup(func() { eng.Close() }) + tb.Cleanup(func() { eng.Close() }) sys.Set(eng) ns, err := Create(logf, sys.Tun.Get(), eng, sys.MagicSock.Get(), dialer, sys.DNSManager.Get(), sys.ProxyMapper(), nil) if err != nil { - t.Fatal(err) + tb.Fatal(err) } - t.Cleanup(func() { ns.Close() }) + tb.Cleanup(func() { ns.Close() }) lb, err := ipnlocal.NewLocalBackend(logf, logid.PublicID{}, sys, 0) if err != nil { - t.Fatalf("NewLocalBackend: %v", err) + tb.Fatalf("NewLocalBackend: %v", err) } ns.atomicIsLocalIPFunc.Store(func(netip.Addr) bool { return true }) @@ -128,7 +130,7 @@ func makeNetstack(t *testing.T, config func(*Impl)) *Impl { config(ns) } if err := ns.Start(lb); err != nil { - t.Fatalf("Start: %v", err) + tb.Fatalf("Start: %v", err) } return ns } @@ -797,3 +799,216 @@ func TestHandleLocalPackets(t *testing.T) { } }) } + +func TestShouldSendToHost(t *testing.T) { + var ( + selfIP4 = netip.MustParseAddr("100.64.1.2") + selfIP6 = netip.MustParseAddr("fd7a:115c:a1e0::123") + ) + + makeTestNetstack := func(tb testing.TB) *Impl { + impl := makeNetstack(tb, func(impl *Impl) { + impl.ProcessSubnets = false + impl.ProcessLocalIPs = false + impl.atomicIsLocalIPFunc.Store(func(addr netip.Addr) bool { + return addr == selfIP4 || addr == selfIP6 + }) + }) + + prefs := ipn.NewPrefs() + prefs.AdvertiseRoutes = []netip.Prefix{ + // $ tailscale debug via 7 10.1.1.0/24 + // fd7a:115c:a1e0:b1a:0:7:a01:100/120 + netip.MustParsePrefix("fd7a:115c:a1e0:b1a:0:7:a01:100/120"), + } + _, err := impl.lb.EditPrefs(&ipn.MaskedPrefs{ + Prefs: *prefs, + AdvertiseRoutesSet: true, + }) + if err != nil { + tb.Fatalf("EditPrefs: %v", err) + } + return impl + } + + testCases := []struct { + name string + src, dst netip.AddrPort + want bool + }{ + // Reply from service IP to localhost should be sent to host, + // not over WireGuard. + { + name: "from_service_ip_to_localhost", + src: netip.AddrPortFrom(serviceIP, 53), + dst: netip.MustParseAddrPort("127.0.0.1:9999"), + want: true, + }, + { + name: "from_service_ip_to_localhost_v6", + src: netip.AddrPortFrom(serviceIPv6, 53), + dst: netip.MustParseAddrPort("[::1]:9999"), + want: true, + }, + // A reply from the local IP to a remote host isn't sent to the + // host, but rather over WireGuard. + { + name: "local_ip_to_remote", + src: netip.AddrPortFrom(selfIP4, 12345), + dst: netip.MustParseAddrPort("100.64.99.88:7777"), + want: false, + }, + { + name: "local_ip_to_remote_v6", + src: netip.AddrPortFrom(selfIP6, 12345), + dst: netip.MustParseAddrPort("[fd7a:115:a1e0::99]:7777"), + want: false, + }, + // A reply from a 4via6 address to a remote host isn't sent to + // the local host, but rather over WireGuard. See: + // https://github.com/tailscale/tailscale/issues/12448 + { + name: "4via6_to_remote", + + // $ tailscale debug via 7 10.1.1.99/24 + // fd7a:115c:a1e0:b1a:0:7:a01:163/120 + src: netip.MustParseAddrPort("[fd7a:115c:a1e0:b1a:0:7:a01:163]:12345"), + dst: netip.MustParseAddrPort("[fd7a:115:a1e0::99]:7777"), + want: false, + }, + // However, a reply from a 4via6 address to the local Tailscale + // IP for this host *is* sent to the local host. See: + // https://github.com/tailscale/tailscale/issues/11304 + { + name: "4via6_to_local", + + // $ tailscale debug via 7 10.1.1.99/24 + // fd7a:115c:a1e0:b1a:0:7:a01:163/120 + src: netip.MustParseAddrPort("[fd7a:115c:a1e0:b1a:0:7:a01:163]:12345"), + dst: netip.AddrPortFrom(selfIP6, 7777), + want: true, + }, + // Traffic from a 4via6 address that we're not handling to + // either the local Tailscale IP or a remote host is sent + // outbound. + // + // In most cases, we won't see this type of traffic in the + // shouldSendToHost function, but let's confirm. + { + name: "other_4via6_to_local", + + // $ tailscale debug via 4444 10.1.1.88/24 + // fd7a:115c:a1e0:b1a:0:7:a01:163/120 + src: netip.MustParseAddrPort("[fd7a:115c:a1e0:b1a:0:115c:a01:158]:12345"), + dst: netip.AddrPortFrom(selfIP6, 7777), + want: false, + }, + { + name: "other_4via6_to_remote", + + // $ tailscale debug via 4444 10.1.1.88/24 + // fd7a:115c:a1e0:b1a:0:7:a01:163/120 + src: netip.MustParseAddrPort("[fd7a:115c:a1e0:b1a:0:115c:a01:158]:12345"), + dst: netip.MustParseAddrPort("[fd7a:115:a1e0::99]:7777"), + want: false, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + var pkt *stack.PacketBuffer + if tt.src.Addr().Is4() { + pkt = makeUDP4PacketBuffer(tt.src, tt.dst) + } else { + pkt = makeUDP6PacketBuffer(tt.src, tt.dst) + } + + ns := makeTestNetstack(t) + if got := ns.shouldSendToHost(pkt); got != tt.want { + t.Errorf("shouldSendToHost returned %v, want %v", got, tt.want) + } + }) + } +} + +func makeUDP4PacketBuffer(src, dst netip.AddrPort) *stack.PacketBuffer { + if !src.Addr().Is4() || !dst.Addr().Is4() { + panic("src and dst must be IPv4") + } + + data := []byte("hello world\n") + + packetLen := header.IPv4MinimumSize + header.UDPMinimumSize + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: packetLen, + Payload: buffer.MakeWithData(data), + }) + + // Initialize the UDP header. + udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize)) + pkt.TransportProtocolNumber = header.UDPProtocolNumber + + length := uint16(pkt.Size()) + udp.Encode(&header.UDPFields{ + SrcPort: src.Port(), + DstPort: dst.Port(), + Length: length, + }) + + // Add IP header + ipHdr := header.IPv4(pkt.NetworkHeader().Push(header.IPv4MinimumSize)) + pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber + ipHdr.Encode(&header.IPv4Fields{ + TotalLength: uint16(packetLen), + Protocol: uint8(header.UDPProtocolNumber), + SrcAddr: tcpip.AddrFrom4(src.Addr().As4()), + DstAddr: tcpip.AddrFrom4(dst.Addr().As4()), + Checksum: 0, + }) + + return pkt +} + +func makeUDP6PacketBuffer(src, dst netip.AddrPort) *stack.PacketBuffer { + if !src.Addr().Is6() || !dst.Addr().Is6() { + panic("src and dst must be IPv6") + } + data := []byte("hello world\n") + + packetLen := header.IPv6MinimumSize + header.UDPMinimumSize + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: packetLen, + Payload: buffer.MakeWithData(data), + }) + + srcAddr := tcpip.AddrFrom16(src.Addr().As16()) + dstAddr := tcpip.AddrFrom16(dst.Addr().As16()) + + // Add IP header + ipHdr := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) + pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber + ipHdr.Encode(&header.IPv6Fields{ + SrcAddr: srcAddr, + DstAddr: dstAddr, + PayloadLength: uint16(header.UDPMinimumSize + len(data)), + TransportProtocol: header.UDPProtocolNumber, + HopLimit: 64, + }) + + // Initialize the UDP header. + udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize)) + pkt.TransportProtocolNumber = header.UDPProtocolNumber + + length := uint16(pkt.Size()) + udp.Encode(&header.UDPFields{ + SrcPort: src.Port(), + DstPort: dst.Port(), + Length: length, + }) + + // Calculate the UDP pseudo-header checksum. + xsum := header.PseudoHeaderChecksum(header.UDPProtocolNumber, srcAddr, dstAddr, uint16(len(udp))) + udp.SetChecksum(^udp.CalculateChecksum(xsum)) + + return pkt +}