mirror of
https://github.com/tailscale/tailscale.git
synced 2024-11-25 19:15:34 +00:00
wgengine/netstack: add test for #12448
This refactors the logic for determining whether a packet should be sent to the host or not into a function, and then adds tests for it. Updates #11304 Updates #12448 Signed-off-by: Andrew Dunham <andrew@du.nham.ca> Change-Id: Ief9afa98eaffae00e21ceb7db073c61b170355e5
This commit is contained in:
parent
88f2d234a4
commit
72c8f7700b
@ -808,47 +808,7 @@ func (ns *Impl) inject() {
|
|||||||
// However, some uses of netstack (presently, magic DNS)
|
// However, some uses of netstack (presently, magic DNS)
|
||||||
// send traffic destined for the local device, hence must
|
// send traffic destined for the local device, hence must
|
||||||
// be injected 'inbound'.
|
// be injected 'inbound'.
|
||||||
sendToHost := false
|
sendToHost := ns.shouldSendToHost(pkt)
|
||||||
|
|
||||||
// Determine if the packet is from a service IP, in which case it
|
|
||||||
// needs to go back into the machines network (inbound) instead of
|
|
||||||
// out.
|
|
||||||
// TODO(tom): Figure out if its safe to modify packet.Parsed to fill in
|
|
||||||
// the IP src/dest even if its missing the rest of the pkt.
|
|
||||||
// That way we dont have to do this twitchy-af byte-yeeting.
|
|
||||||
hdr := pkt.Network()
|
|
||||||
switch v := hdr.(type) {
|
|
||||||
case header.IPv4:
|
|
||||||
srcIP := netip.AddrFrom4(v.SourceAddress().As4())
|
|
||||||
if serviceIP == srcIP {
|
|
||||||
sendToHost = true
|
|
||||||
}
|
|
||||||
case header.IPv6:
|
|
||||||
srcIP := netip.AddrFrom16(v.SourceAddress().As16())
|
|
||||||
if srcIP == serviceIPv6 {
|
|
||||||
sendToHost = true
|
|
||||||
} else if viaRange.Contains(srcIP) {
|
|
||||||
// Only send to the host if this 4via6 route is
|
|
||||||
// something this node handles.
|
|
||||||
if ns.lb != nil && ns.lb.ShouldHandleViaIP(srcIP) {
|
|
||||||
dstIP := netip.AddrFrom16(v.DestinationAddress().As16())
|
|
||||||
// Also, only forward to the host if
|
|
||||||
// the packet is destined for a local
|
|
||||||
// IP; otherwise, we'd send traffic
|
|
||||||
// that's intended for another peer
|
|
||||||
// from the local 4via6 address to the
|
|
||||||
// host instead of outbound to
|
|
||||||
// WireGuard. See:
|
|
||||||
// https://github.com/tailscale/tailscale/issues/12448
|
|
||||||
sendToHost = ns.isLocalIP(dstIP)
|
|
||||||
if debugNetstack() {
|
|
||||||
ns.logf("netstack: sending 4via6 packet to host: src=%v dst=%v", srcIP, dstIP)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
// unknown; don't forward to host
|
|
||||||
}
|
|
||||||
|
|
||||||
// pkt has a non-zero refcount, so injection methods takes
|
// pkt has a non-zero refcount, so injection methods takes
|
||||||
// ownership of one count and will decrement on completion.
|
// ownership of one count and will decrement on completion.
|
||||||
@ -866,6 +826,57 @@ func (ns *Impl) inject() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// shouldSendToHost determines if the provided packet should be sent to the
|
||||||
|
// host (i.e the current machine running Tailscale), in which case it will
|
||||||
|
// return true. It will return false if the packet should be sent outbound, for
|
||||||
|
// transit via WireGuard to another Tailscale node.
|
||||||
|
func (ns *Impl) shouldSendToHost(pkt *stack.PacketBuffer) bool {
|
||||||
|
// Determine if the packet is from a service IP (100.100.100.100 or the
|
||||||
|
// IPv6 variant), in which case it needs to go back into the machine's
|
||||||
|
// network (inbound) instead of out.
|
||||||
|
hdr := pkt.Network()
|
||||||
|
switch v := hdr.(type) {
|
||||||
|
case header.IPv4:
|
||||||
|
srcIP := netip.AddrFrom4(v.SourceAddress().As4())
|
||||||
|
if serviceIP == srcIP {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
case header.IPv6:
|
||||||
|
srcIP := netip.AddrFrom16(v.SourceAddress().As16())
|
||||||
|
if srcIP == serviceIPv6 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if viaRange.Contains(srcIP) {
|
||||||
|
// Only send to the host if this 4via6 route is
|
||||||
|
// something this node handles.
|
||||||
|
if ns.lb != nil && ns.lb.ShouldHandleViaIP(srcIP) {
|
||||||
|
dstIP := netip.AddrFrom16(v.DestinationAddress().As16())
|
||||||
|
// Also, only forward to the host if the packet
|
||||||
|
// is destined for a local IP; otherwise, we'd
|
||||||
|
// send traffic that's intended for another
|
||||||
|
// peer from the local 4via6 address to the
|
||||||
|
// host instead of outbound to WireGuard. See:
|
||||||
|
// https://github.com/tailscale/tailscale/issues/12448
|
||||||
|
if ns.isLocalIP(dstIP) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if debugNetstack() {
|
||||||
|
ns.logf("netstack: sending 4via6 packet to host: src=%v dst=%v", srcIP, dstIP)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// unknown; don't forward to host
|
||||||
|
if debugNetstack() {
|
||||||
|
ns.logf("netstack: unexpected packet in shouldSendToHost: %T", v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// isLocalIP reports whether ip is a Tailscale IP assigned to this
|
// isLocalIP reports whether ip is a Tailscale IP assigned to this
|
||||||
// node directly (but not a subnet-routed IP).
|
// node directly (but not a subnet-routed IP).
|
||||||
func (ns *Impl) isLocalIP(ip netip.Addr) bool {
|
func (ns *Impl) isLocalIP(ip netip.Addr) bool {
|
||||||
|
@ -13,8 +13,10 @@
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"gvisor.dev/gvisor/pkg/buffer"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip"
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
"tailscale.com/envknob"
|
"tailscale.com/envknob"
|
||||||
"tailscale.com/ipn"
|
"tailscale.com/ipn"
|
||||||
"tailscale.com/ipn/ipnlocal"
|
"tailscale.com/ipn/ipnlocal"
|
||||||
@ -94,12 +96,12 @@ func getMemStats() (ms runtime.MemStats) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeNetstack(t *testing.T, config func(*Impl)) *Impl {
|
func makeNetstack(tb testing.TB, config func(*Impl)) *Impl {
|
||||||
tunDev := tstun.NewFake()
|
tunDev := tstun.NewFake()
|
||||||
sys := &tsd.System{}
|
sys := &tsd.System{}
|
||||||
sys.Set(new(mem.Store))
|
sys.Set(new(mem.Store))
|
||||||
dialer := new(tsdial.Dialer)
|
dialer := new(tsdial.Dialer)
|
||||||
logf := tstest.WhileTestRunningLogger(t)
|
logf := tstest.WhileTestRunningLogger(tb)
|
||||||
eng, err := wgengine.NewUserspaceEngine(logf, wgengine.Config{
|
eng, err := wgengine.NewUserspaceEngine(logf, wgengine.Config{
|
||||||
Tun: tunDev,
|
Tun: tunDev,
|
||||||
Dialer: dialer,
|
Dialer: dialer,
|
||||||
@ -107,20 +109,20 @@ func makeNetstack(t *testing.T, config func(*Impl)) *Impl {
|
|||||||
HealthTracker: sys.HealthTracker(),
|
HealthTracker: sys.HealthTracker(),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
tb.Fatal(err)
|
||||||
}
|
}
|
||||||
t.Cleanup(func() { eng.Close() })
|
tb.Cleanup(func() { eng.Close() })
|
||||||
sys.Set(eng)
|
sys.Set(eng)
|
||||||
|
|
||||||
ns, err := Create(logf, sys.Tun.Get(), eng, sys.MagicSock.Get(), dialer, sys.DNSManager.Get(), sys.ProxyMapper(), nil)
|
ns, err := Create(logf, sys.Tun.Get(), eng, sys.MagicSock.Get(), dialer, sys.DNSManager.Get(), sys.ProxyMapper(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
tb.Fatal(err)
|
||||||
}
|
}
|
||||||
t.Cleanup(func() { ns.Close() })
|
tb.Cleanup(func() { ns.Close() })
|
||||||
|
|
||||||
lb, err := ipnlocal.NewLocalBackend(logf, logid.PublicID{}, sys, 0)
|
lb, err := ipnlocal.NewLocalBackend(logf, logid.PublicID{}, sys, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewLocalBackend: %v", err)
|
tb.Fatalf("NewLocalBackend: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ns.atomicIsLocalIPFunc.Store(func(netip.Addr) bool { return true })
|
ns.atomicIsLocalIPFunc.Store(func(netip.Addr) bool { return true })
|
||||||
@ -128,7 +130,7 @@ func makeNetstack(t *testing.T, config func(*Impl)) *Impl {
|
|||||||
config(ns)
|
config(ns)
|
||||||
}
|
}
|
||||||
if err := ns.Start(lb); err != nil {
|
if err := ns.Start(lb); err != nil {
|
||||||
t.Fatalf("Start: %v", err)
|
tb.Fatalf("Start: %v", err)
|
||||||
}
|
}
|
||||||
return ns
|
return ns
|
||||||
}
|
}
|
||||||
@ -797,3 +799,216 @@ func TestHandleLocalPackets(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestShouldSendToHost(t *testing.T) {
|
||||||
|
var (
|
||||||
|
selfIP4 = netip.MustParseAddr("100.64.1.2")
|
||||||
|
selfIP6 = netip.MustParseAddr("fd7a:115c:a1e0::123")
|
||||||
|
)
|
||||||
|
|
||||||
|
makeTestNetstack := func(tb testing.TB) *Impl {
|
||||||
|
impl := makeNetstack(tb, func(impl *Impl) {
|
||||||
|
impl.ProcessSubnets = false
|
||||||
|
impl.ProcessLocalIPs = false
|
||||||
|
impl.atomicIsLocalIPFunc.Store(func(addr netip.Addr) bool {
|
||||||
|
return addr == selfIP4 || addr == selfIP6
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
prefs := ipn.NewPrefs()
|
||||||
|
prefs.AdvertiseRoutes = []netip.Prefix{
|
||||||
|
// $ tailscale debug via 7 10.1.1.0/24
|
||||||
|
// fd7a:115c:a1e0:b1a:0:7:a01:100/120
|
||||||
|
netip.MustParsePrefix("fd7a:115c:a1e0:b1a:0:7:a01:100/120"),
|
||||||
|
}
|
||||||
|
_, err := impl.lb.EditPrefs(&ipn.MaskedPrefs{
|
||||||
|
Prefs: *prefs,
|
||||||
|
AdvertiseRoutesSet: true,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
tb.Fatalf("EditPrefs: %v", err)
|
||||||
|
}
|
||||||
|
return impl
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
src, dst netip.AddrPort
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
// Reply from service IP to localhost should be sent to host,
|
||||||
|
// not over WireGuard.
|
||||||
|
{
|
||||||
|
name: "from_service_ip_to_localhost",
|
||||||
|
src: netip.AddrPortFrom(serviceIP, 53),
|
||||||
|
dst: netip.MustParseAddrPort("127.0.0.1:9999"),
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "from_service_ip_to_localhost_v6",
|
||||||
|
src: netip.AddrPortFrom(serviceIPv6, 53),
|
||||||
|
dst: netip.MustParseAddrPort("[::1]:9999"),
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
// A reply from the local IP to a remote host isn't sent to the
|
||||||
|
// host, but rather over WireGuard.
|
||||||
|
{
|
||||||
|
name: "local_ip_to_remote",
|
||||||
|
src: netip.AddrPortFrom(selfIP4, 12345),
|
||||||
|
dst: netip.MustParseAddrPort("100.64.99.88:7777"),
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "local_ip_to_remote_v6",
|
||||||
|
src: netip.AddrPortFrom(selfIP6, 12345),
|
||||||
|
dst: netip.MustParseAddrPort("[fd7a:115:a1e0::99]:7777"),
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
// A reply from a 4via6 address to a remote host isn't sent to
|
||||||
|
// the local host, but rather over WireGuard. See:
|
||||||
|
// https://github.com/tailscale/tailscale/issues/12448
|
||||||
|
{
|
||||||
|
name: "4via6_to_remote",
|
||||||
|
|
||||||
|
// $ tailscale debug via 7 10.1.1.99/24
|
||||||
|
// fd7a:115c:a1e0:b1a:0:7:a01:163/120
|
||||||
|
src: netip.MustParseAddrPort("[fd7a:115c:a1e0:b1a:0:7:a01:163]:12345"),
|
||||||
|
dst: netip.MustParseAddrPort("[fd7a:115:a1e0::99]:7777"),
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
// However, a reply from a 4via6 address to the local Tailscale
|
||||||
|
// IP for this host *is* sent to the local host. See:
|
||||||
|
// https://github.com/tailscale/tailscale/issues/11304
|
||||||
|
{
|
||||||
|
name: "4via6_to_local",
|
||||||
|
|
||||||
|
// $ tailscale debug via 7 10.1.1.99/24
|
||||||
|
// fd7a:115c:a1e0:b1a:0:7:a01:163/120
|
||||||
|
src: netip.MustParseAddrPort("[fd7a:115c:a1e0:b1a:0:7:a01:163]:12345"),
|
||||||
|
dst: netip.AddrPortFrom(selfIP6, 7777),
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
// Traffic from a 4via6 address that we're not handling to
|
||||||
|
// either the local Tailscale IP or a remote host is sent
|
||||||
|
// outbound.
|
||||||
|
//
|
||||||
|
// In most cases, we won't see this type of traffic in the
|
||||||
|
// shouldSendToHost function, but let's confirm.
|
||||||
|
{
|
||||||
|
name: "other_4via6_to_local",
|
||||||
|
|
||||||
|
// $ tailscale debug via 4444 10.1.1.88/24
|
||||||
|
// fd7a:115c:a1e0:b1a:0:7:a01:163/120
|
||||||
|
src: netip.MustParseAddrPort("[fd7a:115c:a1e0:b1a:0:115c:a01:158]:12345"),
|
||||||
|
dst: netip.AddrPortFrom(selfIP6, 7777),
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "other_4via6_to_remote",
|
||||||
|
|
||||||
|
// $ tailscale debug via 4444 10.1.1.88/24
|
||||||
|
// fd7a:115c:a1e0:b1a:0:7:a01:163/120
|
||||||
|
src: netip.MustParseAddrPort("[fd7a:115c:a1e0:b1a:0:115c:a01:158]:12345"),
|
||||||
|
dst: netip.MustParseAddrPort("[fd7a:115:a1e0::99]:7777"),
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range testCases {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var pkt *stack.PacketBuffer
|
||||||
|
if tt.src.Addr().Is4() {
|
||||||
|
pkt = makeUDP4PacketBuffer(tt.src, tt.dst)
|
||||||
|
} else {
|
||||||
|
pkt = makeUDP6PacketBuffer(tt.src, tt.dst)
|
||||||
|
}
|
||||||
|
|
||||||
|
ns := makeTestNetstack(t)
|
||||||
|
if got := ns.shouldSendToHost(pkt); got != tt.want {
|
||||||
|
t.Errorf("shouldSendToHost returned %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeUDP4PacketBuffer(src, dst netip.AddrPort) *stack.PacketBuffer {
|
||||||
|
if !src.Addr().Is4() || !dst.Addr().Is4() {
|
||||||
|
panic("src and dst must be IPv4")
|
||||||
|
}
|
||||||
|
|
||||||
|
data := []byte("hello world\n")
|
||||||
|
|
||||||
|
packetLen := header.IPv4MinimumSize + header.UDPMinimumSize
|
||||||
|
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||||
|
ReserveHeaderBytes: packetLen,
|
||||||
|
Payload: buffer.MakeWithData(data),
|
||||||
|
})
|
||||||
|
|
||||||
|
// Initialize the UDP header.
|
||||||
|
udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
|
||||||
|
pkt.TransportProtocolNumber = header.UDPProtocolNumber
|
||||||
|
|
||||||
|
length := uint16(pkt.Size())
|
||||||
|
udp.Encode(&header.UDPFields{
|
||||||
|
SrcPort: src.Port(),
|
||||||
|
DstPort: dst.Port(),
|
||||||
|
Length: length,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Add IP header
|
||||||
|
ipHdr := header.IPv4(pkt.NetworkHeader().Push(header.IPv4MinimumSize))
|
||||||
|
pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber
|
||||||
|
ipHdr.Encode(&header.IPv4Fields{
|
||||||
|
TotalLength: uint16(packetLen),
|
||||||
|
Protocol: uint8(header.UDPProtocolNumber),
|
||||||
|
SrcAddr: tcpip.AddrFrom4(src.Addr().As4()),
|
||||||
|
DstAddr: tcpip.AddrFrom4(dst.Addr().As4()),
|
||||||
|
Checksum: 0,
|
||||||
|
})
|
||||||
|
|
||||||
|
return pkt
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeUDP6PacketBuffer(src, dst netip.AddrPort) *stack.PacketBuffer {
|
||||||
|
if !src.Addr().Is6() || !dst.Addr().Is6() {
|
||||||
|
panic("src and dst must be IPv6")
|
||||||
|
}
|
||||||
|
data := []byte("hello world\n")
|
||||||
|
|
||||||
|
packetLen := header.IPv6MinimumSize + header.UDPMinimumSize
|
||||||
|
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||||
|
ReserveHeaderBytes: packetLen,
|
||||||
|
Payload: buffer.MakeWithData(data),
|
||||||
|
})
|
||||||
|
|
||||||
|
srcAddr := tcpip.AddrFrom16(src.Addr().As16())
|
||||||
|
dstAddr := tcpip.AddrFrom16(dst.Addr().As16())
|
||||||
|
|
||||||
|
// Add IP header
|
||||||
|
ipHdr := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
|
||||||
|
pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber
|
||||||
|
ipHdr.Encode(&header.IPv6Fields{
|
||||||
|
SrcAddr: srcAddr,
|
||||||
|
DstAddr: dstAddr,
|
||||||
|
PayloadLength: uint16(header.UDPMinimumSize + len(data)),
|
||||||
|
TransportProtocol: header.UDPProtocolNumber,
|
||||||
|
HopLimit: 64,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Initialize the UDP header.
|
||||||
|
udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
|
||||||
|
pkt.TransportProtocolNumber = header.UDPProtocolNumber
|
||||||
|
|
||||||
|
length := uint16(pkt.Size())
|
||||||
|
udp.Encode(&header.UDPFields{
|
||||||
|
SrcPort: src.Port(),
|
||||||
|
DstPort: dst.Port(),
|
||||||
|
Length: length,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Calculate the UDP pseudo-header checksum.
|
||||||
|
xsum := header.PseudoHeaderChecksum(header.UDPProtocolNumber, srcAddr, dstAddr, uint16(len(udp)))
|
||||||
|
udp.SetChecksum(^udp.CalculateChecksum(xsum))
|
||||||
|
|
||||||
|
return pkt
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user