From ddb4040aa0cd7cdb2eef064658d18e3e75de1c5d Mon Sep 17 00:00:00 2001 From: Charlotte Brandhorst-Satzkorn <46385858+catzkorn@users.noreply.github.com> Date: Tue, 2 May 2023 17:49:56 -0700 Subject: [PATCH] wgengine/magicsock: add address selection for wireguard only endpoints (#7979) This change introduces address selection for wireguard only endpoints. If a endpoint has not been used before, an address is randomly selected to be used based on information we know about, such as if they are able to use IPv4 or IPv6. When an address is initially selected, we also initiate a new ICMP ping to the endpoints addresses to determine which endpoint offers the best latency. This information is then used to update which endpoint we should be using based on the best possible route. If the latency is the same for a IPv4 and an IPv6 address, IPv6 will be used. Updates #7826 Signed-off-by: Charlotte Brandhorst-Satzkorn --- cmd/tailscaled/depaware.txt | 2 +- wgengine/magicsock/magicsock.go | 298 ++++++++++++---- wgengine/magicsock/magicsock_test.go | 488 ++++++++++++++++++++++++++- 3 files changed, 714 insertions(+), 74 deletions(-) diff --git a/cmd/tailscaled/depaware.txt b/cmd/tailscaled/depaware.txt index b51b3a3de..16c8c6bba 100644 --- a/cmd/tailscaled/depaware.txt +++ b/cmd/tailscaled/depaware.txt @@ -245,7 +245,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de 💣 tailscale.com/net/netstat from tailscale.com/ipn/ipnauth+ tailscale.com/net/netutil from tailscale.com/ipn/ipnlocal+ tailscale.com/net/packet from tailscale.com/net/tstun+ - tailscale.com/net/ping from tailscale.com/net/netcheck + tailscale.com/net/ping from tailscale.com/net/netcheck+ tailscale.com/net/portmapper from tailscale.com/net/netcheck+ tailscale.com/net/proxymux from tailscale.com/cmd/tailscaled tailscale.com/net/routetable from tailscale.com/doctor/routetable diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 95af25891..3223441c4 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -50,6 +50,7 @@ "tailscale.com/net/netmon" "tailscale.com/net/netns" "tailscale.com/net/packet" + "tailscale.com/net/ping" "tailscale.com/net/portmapper" "tailscale.com/net/sockstats" "tailscale.com/net/stun" @@ -59,6 +60,7 @@ "tailscale.com/tstime" "tailscale.com/tstime/mono" "tailscale.com/types/key" + "tailscale.com/types/lazy" "tailscale.com/types/logger" "tailscale.com/types/netmap" "tailscale.com/types/nettype" @@ -209,11 +211,16 @@ func (m *peerMap) upsertEndpoint(ep *endpoint, oldDiscoKey key.DiscoPublic) { if epDisco == nil || oldDiscoKey != epDisco.key { delete(m.nodesOfDisco[oldDiscoKey], ep.publicKey) } - if epDisco == nil { - // If the peer does not support Disco, but it does have an endpoint address, - // attempt to use that (e.g. WireGuardOnly peers). - if ep.bestAddr.AddrPort.IsValid() { - m.setNodeKeyForIPPort(ep.bestAddr.AddrPort, ep.publicKey) + if ep.isWireguardOnly { + // If the peer is a WireGuard only peer, add all of its endpoints. + + // TODO(raggi,catzkorn): this could mean that if a "isWireguardOnly" + // peer has, say, 192.168.0.2 and so does a tailscale peer, the + // wireguard one will win. That may not be the outcome that we want - + // perhaps we should prefer bestAddr.AddrPort if it is set? + // see tailscale/tailscale#7994 + for ipp := range ep.endpointState { + m.setNodeKeyForIPPort(ipp, ep.publicKey) } return @@ -473,6 +480,9 @@ type Conn struct { // peerLastDerp tracks which DERP node we last used to speak with a // peer. It's only used to quiet logging, so we only log on change. peerLastDerp map[key.NodePublic]int + + // wgPinger is the WireGuard only pinger used for latency measurements. + wgPinger lazy.SyncValue[*ping.Pinger] } // SetDebugLoggingEnabled controls whether spammy debug logging is enabled. @@ -2766,6 +2776,7 @@ func (c *Conn) SetNetworkMap(nm *netmap.NetworkMap) { sentPing: map[stun.TxID]sentPing{}, endpointState: map[netip.AddrPort]*endpointState{}, heartbeatDisabled: heartbeatDisabled, + isWireguardOnly: n.IsWireGuardOnly, } if len(n.Addresses) > 0 { ep.nodeAddr = n.Addresses[0].Addr() @@ -3143,6 +3154,11 @@ func (c *Conn) Close() error { for c.goroutinesRunningLocked() { c.muCond.Wait() } + + if pinger := c.getPinger(); pinger != nil { + pinger.Close() + } + return nil } @@ -4084,9 +4100,14 @@ type endpointDisco struct { short string // ShortString of discoKey. } -// endpoint is a wireguard/conn.Endpoint that picks the best -// available path to communicate with a peer, based on network -// conditions and what the peer supports. +// endpoint is a wireguard/conn.Endpoint. In wireguard-go and kernel WireGuard +// there is only one endpoint for a peer, but in Tailscale we distribute a +// number of possible endpoints for a peer which would include the all the +// likely addresses at which a peer may be reachable. This endpoint type holds +// the information required that when WiregGuard-Go wants to send to a +// particular peer (essentally represented by this endpoint type), the send +// function can use the currnetly best known Tailscale endpoint to send packets +// to the peer. type endpoint struct { // atomically accessed; declared first for alignment reasons lastRecv mono.Time @@ -4108,7 +4129,7 @@ type endpoint struct { heartBeatTimer *time.Timer // nil when idle lastSend mono.Time // last time there was outgoing packets sent to this peer (from wireguard-go) - lastFullPing mono.Time // last time we pinged all endpoints + lastFullPing mono.Time // last time we pinged all disco endpoints derpAddr netip.AddrPort // fallback/bootstrap path, if non-zero (non-zero for well-behaved clients) bestAddr addrLatency // best non-DERP path; zero if none @@ -4126,7 +4147,8 @@ type endpoint struct { heartbeatDisabled bool pathFinderRunning bool - expired bool // whether the node has expired + expired bool // whether the node has expired + isWireguardOnly bool // whether the endpoint is WireGuard only } type pendingCLIPing struct { @@ -4238,6 +4260,15 @@ func (st *endpointState) shouldDeleteLocked() bool { } } +// latencyLocked returns the most recent latency measurement, if any. +// endpoint.mu must be held. +func (st *endpointState) latencyLocked() (lat time.Duration, ok bool) { + if len(st.recentPongs) == 0 { + return 0, false + } + return st.recentPongs[st.recentPong].latency, true +} + func (de *endpoint) deleteEndpointLocked(why string, ep netip.AddrPort) { de.debugUpdates.Add(EndpointChange{ When: time.Now(), @@ -4321,17 +4352,87 @@ func (de *endpoint) DstToBytes() []byte { return packIPPort(de.fakeWGAddr) } // addrForSendLocked returns the address(es) that should be used for // sending the next packet. Zero, one, or both of UDP address and DERP -// addr may be non-zero. +// addr may be non-zero. If the endpoint is WireGuard only and does not have +// latency information, a bool is returned to indiciate that the +// WireGuard latency discovery pings should be sent. // // de.mu must be held. -func (de *endpoint) addrForSendLocked(now mono.Time) (udpAddr, derpAddr netip.AddrPort) { +func (de *endpoint) addrForSendLocked(now mono.Time) (udpAddr, derpAddr netip.AddrPort, sendWGPing bool) { udpAddr = de.bestAddr.AddrPort - if !udpAddr.IsValid() || now.After(de.trustBestAddrUntil) { - // We had a bestAddr but it expired so send both to it - // and DERP. - derpAddr = de.derpAddr + + if udpAddr.IsValid() && !now.After(de.trustBestAddrUntil) { + return udpAddr, netip.AddrPort{}, false } - return + + if de.isWireguardOnly { + // If the endpoint is wireguard-only, we don't have a DERP + // address to send to, so we have to send to the UDP address. + udpAddr, shouldPing := de.addrForWireGuardSendLocked(now) + return udpAddr, netip.AddrPort{}, shouldPing + } + + // We had a bestAddr but it expired so send both to it + // and DERP. + return udpAddr, de.derpAddr, false +} + +// addrForWireGuardSendLocked returns the address that should be used for +// sending the next packet. If a packet has never or not recently been sent to +// the endpoint, then a randomly selected address for the endpoint is returned, +// as well as a bool indiciating that WireGuard discovery pings should be started. +// If the addresses have latency information available, then the address with the +// best latency is used. +// +// de.mu must be held. +func (de *endpoint) addrForWireGuardSendLocked(now mono.Time) (udpAddr netip.AddrPort, shouldPing bool) { + // lowestLatency is a high duration initially, so we + // can be sure we're going to have a duration lower than this + // for the first latency retrieved. + lowestLatency := time.Hour + for ipp, state := range de.endpointState { + if latency, ok := state.latencyLocked(); ok { + if latency < lowestLatency || latency == lowestLatency && ipp.Addr().Is6() { + // If we have the same latency,IPv6 is prioritized. + // TODO(catzkorn): Consider a small increase in latency to use + // IPv6 in comparison to IPv4, when possible. + lowestLatency = latency + udpAddr = ipp + } + } + } + + if udpAddr.IsValid() { + // Set trustBestAddrUntil to an hour, so we will + // continue to use this address for a long period of time. + de.bestAddr.AddrPort = udpAddr + de.trustBestAddrUntil = now.Add(1 * time.Hour) + return udpAddr, false + } + + candidates := make([]netip.AddrPort, 0, len(de.endpointState)) + for ipp := range de.endpointState { + if ipp.Addr().Is4() && de.c.noV4.Load() { + continue + } + if ipp.Addr().Is6() && de.c.noV6.Load() { + continue + } + candidates = append(candidates, ipp) + } + // Randomly select an address to use until we retrieve latency information + // and give it a short trustBestAddrUntil time so we avoid flapping between + // addresses while waiting on latency information to be populated. + udpAddr = candidates[rand.Intn(len(candidates))] + de.bestAddr.AddrPort = udpAddr + if len(candidates) == 1 { + // if we only have one address that we can send data too, + // we should trust it for a longer period of time. + de.trustBestAddrUntil = now.Add(1 * time.Hour) + } else { + de.trustBestAddrUntil = now.Add(15 * time.Second) + } + + return udpAddr, len(candidates) > 1 } // heartbeat is called every heartbeatInterval to keep the best UDP path alive, @@ -4359,14 +4460,14 @@ func (de *endpoint) heartbeat() { } now := mono.Now() - udpAddr, _ := de.addrForSendLocked(now) + udpAddr, _, _ := de.addrForSendLocked(now) if udpAddr.IsValid() { // We have a preferred path. Ping that every 2 seconds. - de.startPingLocked(udpAddr, now, pingHeartbeat) + de.startDiscoPingLocked(udpAddr, now, pingHeartbeat) } if de.wantFullPingLocked(now) { - de.sendPingsLocked(now, true) + de.sendDiscoPingsLocked(now, true) } de.heartBeatTimer = time.AfterFunc(heartbeatInterval, de.heartbeat) @@ -4417,19 +4518,19 @@ func (de *endpoint) cliPing(res *ipnstate.PingResult, cb func(*ipnstate.PingResu de.pendingCLIPings = append(de.pendingCLIPings, pendingCLIPing{res, cb}) now := mono.Now() - udpAddr, derpAddr := de.addrForSendLocked(now) + udpAddr, derpAddr, _ := de.addrForSendLocked(now) if derpAddr.IsValid() { - de.startPingLocked(derpAddr, now, pingCLI) + de.startDiscoPingLocked(derpAddr, now, pingCLI) } if udpAddr.IsValid() && now.Before(de.trustBestAddrUntil) { // Already have an active session, so just ping the address we're using. // Otherwise "tailscale ping" results to a node on the local network // can look like they're bouncing between, say 10.0.0.0/9 and the peer's // IPv6 address, both 1ms away, and it's random who replies first. - de.startPingLocked(udpAddr, now, pingCLI) + de.startDiscoPingLocked(udpAddr, now, pingCLI) } else { for ep := range de.endpointState { - de.startPingLocked(ep, now, pingCLI) + de.startDiscoPingLocked(ep, now, pingCLI) } } de.noteActiveLocked() @@ -4459,9 +4560,14 @@ func (de *endpoint) send(buffs [][]byte) error { } now := mono.Now() - udpAddr, derpAddr := de.addrForSendLocked(now) - if !udpAddr.IsValid() || now.After(de.trustBestAddrUntil) { - de.sendPingsLocked(now, true) + udpAddr, derpAddr, startWGPing := de.addrForSendLocked(now) + + if de.isWireguardOnly { + if startWGPing { + de.sendWireGuardOnlyPingsLocked(now) + } + } else if !udpAddr.IsValid() || now.After(de.trustBestAddrUntil) { + de.sendDiscoPingsLocked(now, true) } de.noteActiveLocked() de.mu.Unlock() @@ -4499,7 +4605,7 @@ func (de *endpoint) send(buffs [][]byte) error { return err } -func (de *endpoint) pingTimeout(txid stun.TxID) { +func (de *endpoint) discoPingTimeout(txid stun.TxID) { de.mu.Lock() defer de.mu.Unlock() sp, ok := de.sentPing[txid] @@ -4509,20 +4615,20 @@ func (de *endpoint) pingTimeout(txid stun.TxID) { if debugDisco() || !de.bestAddr.IsValid() || mono.Now().After(de.trustBestAddrUntil) { de.c.dlogf("[v1] magicsock: disco: timeout waiting for pong %x from %v (%v, %v)", txid[:6], sp.to, de.publicKey.ShortString(), de.discoShort()) } - de.removeSentPingLocked(txid, sp) + de.removeSentDiscoPingLocked(txid, sp) } -// forgetPing is called by a timer when a ping either fails to send or +// forgetDiscoPing is called by a timer when a ping either fails to send or // has taken too long to get a pong reply. -func (de *endpoint) forgetPing(txid stun.TxID) { +func (de *endpoint) forgetDiscoPing(txid stun.TxID) { de.mu.Lock() defer de.mu.Unlock() if sp, ok := de.sentPing[txid]; ok { - de.removeSentPingLocked(txid, sp) + de.removeSentDiscoPingLocked(txid, sp) } } -func (de *endpoint) removeSentPingLocked(txid stun.TxID, sp sentPing) { +func (de *endpoint) removeSentDiscoPingLocked(txid stun.TxID, sp sentPing) { // Stop the timer for the case where sendPing failed to write to UDP. // In the case of a timer already having fired, this is a no-op: sp.timer.Stop() @@ -4542,7 +4648,7 @@ func (de *endpoint) sendDiscoPing(ep netip.AddrPort, discoKey key.DiscoPublic, t NodeKey: de.c.publicKeyAtomic.Load(), }, logLevel) if !sent { - de.forgetPing(txid) + de.forgetDiscoPing(txid) } } @@ -4564,7 +4670,7 @@ func (de *endpoint) sendDiscoPing(ep netip.AddrPort, discoKey key.DiscoPublic, t pingCLI ) -func (de *endpoint) startPingLocked(ep netip.AddrPort, now mono.Time, purpose discoPingPurpose) { +func (de *endpoint) startDiscoPingLocked(ep netip.AddrPort, now mono.Time, purpose discoPingPurpose) { if runtime.GOOS == "js" { return } @@ -4587,7 +4693,7 @@ func (de *endpoint) startPingLocked(ep netip.AddrPort, now mono.Time, purpose di de.sentPing[txid] = sentPing{ to: ep, at: now, - timer: time.AfterFunc(pingTimeoutDuration, func() { de.pingTimeout(txid) }), + timer: time.AfterFunc(pingTimeoutDuration, func() { de.discoPingTimeout(txid) }), purpose: purpose, } logLevel := discoLog @@ -4597,7 +4703,7 @@ func (de *endpoint) startPingLocked(ep netip.AddrPort, now mono.Time, purpose di go de.sendDiscoPing(ep, epDisco.key, txid, logLevel) } -func (de *endpoint) sendPingsLocked(now mono.Time, sendCallMeMaybe bool) { +func (de *endpoint) sendDiscoPingsLocked(now mono.Time, sendCallMeMaybe bool) { de.lastFullPing = now var sentAny bool for ep, st := range de.endpointState { @@ -4619,7 +4725,7 @@ func (de *endpoint) sendPingsLocked(now mono.Time, sendCallMeMaybe bool) { de.c.dlogf("[v1] magicsock: disco: send, starting discovery for %v (%v)", de.publicKey.ShortString(), de.discoShort()) } - de.startPingLocked(ep, now, pingDiscovery) + de.startDiscoPingLocked(ep, now, pingDiscovery) } derpAddr := de.derpAddr if sentAny && sendCallMeMaybe && derpAddr.IsValid() { @@ -4632,9 +4738,99 @@ func (de *endpoint) sendPingsLocked(now mono.Time, sendCallMeMaybe bool) { } } +// sendWireGuardOnlyPingsLocked evaluates all available addresses for +// a WireGuard only endpoint and initates an ICMP ping for useable +// addresses. +func (de *endpoint) sendWireGuardOnlyPingsLocked(now mono.Time) { + if runtime.GOOS == "js" { + return + } + + // Normally the we only send pings at a low rate as the decision to start + // sending a ping sets bestAddrAtUntil with a reasonable time to keep trying + // that address, however, if that code changed we may want to be sure that + // we don't ever send excessive pings to avoid impact to the client/user. + if !now.After(de.lastFullPing.Add(10 * time.Second)) { + return + } + de.lastFullPing = now + + for ipp := range de.endpointState { + if ipp.Addr().Is4() && de.c.noV4.Load() { + continue + } + if ipp.Addr().Is6() && de.c.noV6.Load() { + continue + } + + go de.sendWireGuardOnlyPing(ipp, now) + } +} + +// getPinger lazily instantiates a pinger and returns it, if it was +// already instantiated it returns the existing one. +func (c *Conn) getPinger() *ping.Pinger { + return c.wgPinger.Get(func() *ping.Pinger { + return ping.New(c.connCtx, c.dlogf, netns.Listener(c.logf, c.netMon)) + }) +} + +// sendWireGuardOnlyPing sends a ICMP ping to a WireGuard only address to +// discover the latency. +func (de *endpoint) sendWireGuardOnlyPing(ipp netip.AddrPort, now mono.Time) { + ctx, cancel := context.WithTimeout(de.c.connCtx, 5*time.Second) + defer cancel() + + de.setLastPing(ipp, now) + + addr := &net.IPAddr{ + IP: net.IP(ipp.Addr().AsSlice()), + Zone: ipp.Addr().Zone(), + } + + p := de.c.getPinger() + if p == nil { + de.c.logf("[v2] magicsock: sendWireGuardOnlyPingLocked: pinger is nil") + return + } + + latency, err := p.Send(ctx, addr, nil) + if err != nil { + de.c.logf("[v2] magicsock: sendWireGuardOnlyPingLocked: %s", err) + return + } + + de.mu.Lock() + defer de.mu.Unlock() + + state, ok := de.endpointState[ipp] + if !ok { + return + } + state.addPongReplyLocked(pongReply{ + latency: latency, + pongAt: now, + from: ipp, + pongSrc: netip.AddrPort{}, // We don't know this. + }) +} + +// setLastPing sets lastPing on the endpointState to now. +func (de *endpoint) setLastPing(ipp netip.AddrPort, now mono.Time) { + de.mu.Lock() + defer de.mu.Unlock() + state, ok := de.endpointState[ipp] + if !ok { + return + } + state.lastPing = now +} + +// updateFromNode updates the endpoint based on a tailcfg.Node from a NetMap +// update. func (de *endpoint) updateFromNode(n *tailcfg.Node, heartbeatDisabled bool) { if n == nil { - panic("nil node when updating disco ep") + panic("nil node when updating endpoint") } de.mu.Lock() defer de.mu.Unlock() @@ -4642,22 +4838,6 @@ func (de *endpoint) updateFromNode(n *tailcfg.Node, heartbeatDisabled bool) { de.heartbeatDisabled = heartbeatDisabled de.expired = n.Expired - // TODO(#7826): add support for more than one endpoint for pure WireGuard - // peers, and/or support for probing "bestness" for endpoints. - if n.IsWireGuardOnly { - for _, ep := range n.Endpoints { - ipp, err := netip.ParseAddrPort(ep) - if err != nil { - de.c.logf("magicsock: invalid endpoint: %s %s", ep, err) - continue - } - de.bestAddr = addrLatency{ - AddrPort: ipp, - } - break - } - } - epDisco := de.disco.Load() var discoKey key.DiscoPublic if epDisco != nil { @@ -4810,7 +4990,7 @@ func (de *endpoint) handlePongConnLocked(m *disco.Pong, di *discoInfo, src netip return false } knownTxID = true // for naked returns below - de.removeSentPingLocked(m.TxID, sp) + de.removeSentDiscoPingLocked(m.TxID, sp) now := mono.Now() latency := now.Sub(sp.at) @@ -5026,7 +5206,7 @@ func (de *endpoint) handleCallMeMaybe(m *disco.CallMeMaybe) { for _, st := range de.endpointState { st.lastPing = 0 } - de.sendPingsLocked(mono.Now(), false) + de.sendDiscoPingsLocked(mono.Now(), false) } func (de *endpoint) populatePeerStatus(ps *ipnstate.PeerStatus) { @@ -5043,7 +5223,7 @@ func (de *endpoint) populatePeerStatus(ps *ipnstate.PeerStatus) { ps.LastWrite = de.lastSend.WallTime() ps.Active = now.Sub(de.lastSend) < sessionActiveTimeout - if udpAddr, derpAddr := de.addrForSendLocked(now); udpAddr.IsValid() && !derpAddr.IsValid() { + if udpAddr, derpAddr, _ := de.addrForSendLocked(now); udpAddr.IsValid() && !derpAddr.IsValid() { ps.CurAddr = udpAddr.String() } } @@ -5086,7 +5266,7 @@ func (de *endpoint) resetLocked() { es.lastPing = 0 } for txid, sp := range de.sentPing { - de.removeSentPingLocked(txid, sp) + de.removeSentDiscoPingLocked(txid, sp) } } diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index 9115a3eea..dee6c8d47 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -23,6 +23,7 @@ "strconv" "strings" "sync" + "sync/atomic" "testing" "time" "unsafe" @@ -33,6 +34,8 @@ "go4.org/mem" "golang.org/x/exp/maps" "golang.org/x/exp/slices" + "golang.org/x/net/icmp" + "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" "tailscale.com/cmd/testwrapper/flakytest" "tailscale.com/derp" @@ -42,11 +45,13 @@ "tailscale.com/net/connstats" "tailscale.com/net/netaddr" "tailscale.com/net/packet" + "tailscale.com/net/ping" "tailscale.com/net/stun/stuntest" "tailscale.com/net/tstun" "tailscale.com/tailcfg" "tailscale.com/tstest" "tailscale.com/tstest/natlab" + "tailscale.com/tstime/mono" "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/netlogtype" @@ -2117,9 +2122,8 @@ func Test_batchingUDPConn_coalesceMessages(t *testing.T) { } // newWireguard starts up a new wireguard-go device attached to a test tun, and -// returns the device, tun and netpoint address. To add peers call device.IpcSet -// with UAPI instructions. -func newWireguard(t *testing.T, uapi string, aips []netip.Prefix) (*device.Device, *tuntest.ChannelTUN, netip.AddrPort) { +// returns the device, tun and endpoint port. To add peers call device.IpcSet with UAPI instructions. +func newWireguard(t *testing.T, uapi string, aips []netip.Prefix) (*device.Device, *tuntest.ChannelTUN, uint16) { wgtun := tuntest.NewChannelTUN() wglogf := func(f string, args ...any) { t.Logf("wg-go: "+f, args...) @@ -2138,8 +2142,7 @@ func newWireguard(t *testing.T, uapi string, aips []netip.Prefix) (*device.Devic t.Fatal(err) } - var wgEp netip.AddrPort - + var port uint16 s, err := wgdev.IpcGet() if err != nil { t.Fatal(err) @@ -2151,17 +2154,16 @@ func newWireguard(t *testing.T, uapi string, aips []netip.Prefix) (*device.Devic } k, v, _ := strings.Cut(line, "=") if k == "listen_port" { - wgEp = netip.MustParseAddrPort("127.0.0.1:" + v) + p, err := strconv.ParseUint(v, 10, 16) + if err != nil { + panic(err) + } + port = uint16(p) break } } - if !wgEp.IsValid() { - t.Fatalf("failed to get endpoint out of wg-go") - } - t.Logf("wg-go endpoint: %s", wgEp) - - return wgdev, wgtun, wgEp + return wgdev, wgtun, port } func TestIsWireGuardOnlyPeer(t *testing.T) { @@ -2176,8 +2178,9 @@ func TestIsWireGuardOnlyPeer(t *testing.T) { uapi := fmt.Sprintf("private_key=%s\npublic_key=%s\nallowed_ip=%s\n\n", wgkey.UntypedHexString(), tskey.Public().UntypedHexString(), tsaip.String()) - wgdev, wgtun, wgEp := newWireguard(t, uapi, []netip.Prefix{wgaip}) + wgdev, wgtun, port := newWireguard(t, uapi, []netip.Prefix{wgaip}) defer wgdev.Close() + wgEp := netip.AddrPortFrom(netip.MustParseAddr("127.0.0.1"), port) m := newMagicStackWithKey(t, t.Logf, localhostListener{}, derpMap, tskey) defer m.Close() @@ -2233,8 +2236,9 @@ func TestIsWireGuardOnlyPeerWithMasquerade(t *testing.T) { uapi := fmt.Sprintf("private_key=%s\npublic_key=%s\nallowed_ip=%s\n\n", wgkey.UntypedHexString(), tskey.Public().UntypedHexString(), masqip.String()) - wgdev, wgtun, wgEp := newWireguard(t, uapi, []netip.Prefix{wgaip}) + wgdev, wgtun, port := newWireguard(t, uapi, []netip.Prefix{wgaip}) defer wgdev.Close() + wgEp := netip.AddrPortFrom(netip.MustParseAddr("127.0.0.1"), port) m := newMagicStackWithKey(t, t.Logf, localhostListener{}, derpMap, tskey) defer m.Close() @@ -2397,3 +2401,459 @@ func TestEndpointTracker(t *testing.T) { } } } + +// applyNetworkMap is a test helper that sets the network map and +// configures WG. +func applyNetworkMap(t *testing.T, m *magicStack, nm *netmap.NetworkMap) { + t.Helper() + m.conn.SetNetworkMap(nm) + // Make sure we can't use v6 to avoid test failures. + m.conn.noV6.Store(true) + + // Turn the network map into a wireguard config (for the tailscale internal wireguard device). + cfg, err := nmcfg.WGCfg(nm, t.Logf, netmap.AllowSingleHosts|netmap.AllowSubnetRoutes, "") + if err != nil { + t.Fatal(err) + } + // Apply the wireguard config to the tailscale internal wireguard device. + if err := m.Reconfig(cfg); err != nil { + t.Fatal(err) + } +} + +func TestIsWireGuardOnlyPickEndpointByPing(t *testing.T) { + clock := &tstest.Clock{} + derpMap, cleanup := runDERPAndStun(t, t.Logf, localhostListener{}, netaddr.IPv4(127, 0, 0, 1)) + defer cleanup() + + // Create a TS client. + tskey := key.NewNode() + tsaip := netip.MustParsePrefix("100.111.222.111/32") + + // Create a WireGuard only client. + wgkey := key.NewNode() + wgaip := netip.MustParsePrefix("100.222.111.222/32") + + uapi := fmt.Sprintf("private_key=%s\npublic_key=%s\nallowed_ip=%s\n\n", + wgkey.UntypedHexString(), tskey.Public().UntypedHexString(), tsaip.String()) + + wgdev, wgtun, port := newWireguard(t, uapi, []netip.Prefix{wgaip}) + defer wgdev.Close() + wgEp := netip.AddrPortFrom(netip.MustParseAddr("127.0.0.1"), port) + wgEp2 := netip.AddrPortFrom(netip.MustParseAddr("127.0.0.2"), port) + + m := newMagicStackWithKey(t, t.Logf, localhostListener{}, derpMap, tskey) + defer m.Close() + + pr := newPingResponder(t) + // Get a destination address which includes a port, so that UDP packets flow + // to the correct place, the mockPinger will use this to direct port-less + // pings to this place. + pingDest := pr.LocalAddr() + + // Create and start the pinger that is used for the + // wireguard only endpoint pings + p, closeP := mockPinger(t, clock, pingDest) + defer closeP() + m.conn.wgPinger.Set(p) + + // Create an IPv6 endpoint which should not receive any traffic. + v6, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.ParseIP("::"), Port: 0}) + if err != nil { + t.Fatal(err) + } + badEpRecv := make(chan []byte) + go func() { + defer v6.Close() + for { + b := make([]byte, 1500) + n, _, err := v6.ReadFrom(b) + if err != nil { + close(badEpRecv) + return + } + badEpRecv <- b[:n] + } + }() + wgEpV6 := netip.MustParseAddrPort(v6.LocalAddr().String()) + + nm := &netmap.NetworkMap{ + Name: "ts", + PrivateKey: m.privateKey, + NodeKey: m.privateKey.Public(), + Addresses: []netip.Prefix{tsaip}, + Peers: []*tailcfg.Node{ + { + Key: wgkey.Public(), + Endpoints: []string{wgEp.String(), wgEp2.String(), wgEpV6.String()}, + IsWireGuardOnly: true, + Addresses: []netip.Prefix{wgaip}, + AllowedIPs: []netip.Prefix{wgaip}, + }, + }, + } + + applyNetworkMap(t, m, nm) + + buf := tuntest.Ping(wgaip.Addr(), tsaip.Addr()) + m.tun.Outbound <- buf + + select { + case p := <-wgtun.Inbound: + if !bytes.Equal(p, buf) { + t.Errorf("got unexpected packet: %x", p) + } + case <-badEpRecv: + t.Fatal("got packet on bad endpoint") + case <-time.After(5 * time.Second): + t.Fatal("no packet after 1s") + } + + pi, ok := m.conn.peerMap.byNodeKey[wgkey.Public()] + if !ok { + t.Fatal("wgkey doesn't exist in peer map") + } + + // Check that we got a valid address set on the first send - this + // will be randomly selected, but because we have noV6 set to true, + // it will be the IPv4 address. + if !pi.ep.bestAddr.Addr().IsValid() { + t.Fatal("bestaddr was nil") + } + + if pi.ep.trustBestAddrUntil.Before(mono.Now().Add(14 * time.Second)) { + t.Errorf("trustBestAddrUntil time wasn't set to 15 seconds in the future: got %v", pi.ep.trustBestAddrUntil) + } + + for ipp, state := range pi.ep.endpointState { + if ipp == wgEp { + if len(state.recentPongs) != 1 { + t.Errorf("IPv4 address did not have a recentPong entry: got %v, want %v", len(state.recentPongs), 1) + } + // Set the latency extremely low so we choose this endpoint during the next + // addrForSendLocked call. + state.recentPongs[state.recentPong].latency = time.Nanosecond + } + + if ipp == wgEp2 { + if len(state.recentPongs) != 1 { + t.Errorf("IPv4 address did not have a recentPong entry: got %v, want %v", len(state.recentPongs), 1) + } + // Set the latency extremely high so we dont choose endpoint during the next + // addrForSendLocked call. + state.recentPongs[state.recentPong].latency = time.Second + } + + if ipp == wgEpV6 && len(state.recentPongs) != 0 { + t.Fatal("IPv6 should not have recentPong: IPv6 is not useable") + } + } + + // Set trustBestAddrUnitl to now, so addrForSendLocked goes through the + // latency selection flow. + pi.ep.trustBestAddrUntil = mono.Now().Add(-time.Second) + + buf = tuntest.Ping(wgaip.Addr(), tsaip.Addr()) + m.tun.Outbound <- buf + + select { + case p := <-wgtun.Inbound: + if !bytes.Equal(p, buf) { + t.Errorf("got unexpected packet: %x", p) + } + case <-badEpRecv: + t.Fatal("got packet on bad endpoint") + case <-time.After(5 * time.Second): + t.Fatal("no packet after 1s") + } + + // Check that we have responded to a WireGuard only ping twice. + if pr.responseCount != 2 { + t.Fatal("pingresponder response count was not 2", pr.responseCount) + } + + pi, ok = m.conn.peerMap.byNodeKey[wgkey.Public()] + if !ok { + t.Fatal("wgkey doesn't exist in peer map") + } + + if !pi.ep.bestAddr.Addr().IsValid() { + t.Error("no bestAddr address was set") + } + + if pi.ep.bestAddr.Addr() != wgEp.Addr() { + t.Errorf("bestAddr was not set to the expected IPv4 address: got %v, want %v", pi.ep.bestAddr.Addr().String(), wgEp.Addr()) + } + + if pi.ep.trustBestAddrUntil.IsZero() { + t.Fatal("trustBestAddrUntil was not set") + } + + if pi.ep.trustBestAddrUntil.Before(mono.Now().Add(55 * time.Minute)) { + // Set to 55 minutes incase of sloooow tests. + t.Errorf("trustBestAddrUntil time wasn't set to an hour in the future: got %v", pi.ep.trustBestAddrUntil) + } +} + +// udpingPacketConn will convert potentially ICMP destination addrs to UDP +// destination addrs in WriteTo so that a test that is intending to send ICMP +// traffic will instead send UDP traffic, without the higher level Pinger being +// aware of this difference. +type udpingPacketConn struct { + net.PacketConn + // destPort will be configured by the test to be the peer expected to respond to a ping. + destPort uint16 +} + +func (u *udpingPacketConn) WriteTo(body []byte, dest net.Addr) (int, error) { + switch d := dest.(type) { + case *net.IPAddr: + udpAddr := &net.UDPAddr{ + IP: d.IP, + Port: int(u.destPort), + Zone: d.Zone, + } + return u.PacketConn.WriteTo(body, udpAddr) + } + return 0, fmt.Errorf("unimplemented udpingPacketConn for %T", dest) +} + +type mockListenPacketer struct { + conn4 net.PacketConn + conn6 net.PacketConn +} + +func (mlp *mockListenPacketer) ListenPacket(ctx context.Context, typ string, addr string) (net.PacketConn, error) { + switch typ { + case "ip4:icmp": + return mlp.conn4, nil + case "ip6:icmp": + return mlp.conn6, nil + } + return nil, fmt.Errorf("unimplemented ListenPacketForTesting for %s", typ) +} + +func mockPinger(t *testing.T, clock *tstest.Clock, dest net.Addr) (*ping.Pinger, func()) { + ctx := context.Background() + + dIPP := netip.MustParseAddrPort(dest.String()) + // In tests, we use UDP so that we can test without being root; this + // doesn't matter because we mock out the ICMP reply below to be a real + // ICMP echo reply packet. + conn4, err := net.ListenPacket("udp4", "127.0.0.1:0") + if err != nil { + t.Fatalf("net.ListenPacket: %v", err) + } + conn6, err := net.ListenPacket("udp6", "[::]:0") + if err != nil { + t.Fatalf("net.ListenPacket: %v", err) + } + + conn4 = &udpingPacketConn{ + PacketConn: conn4, + destPort: dIPP.Port(), + } + + conn6 = &udpingPacketConn{ + PacketConn: conn6, + destPort: dIPP.Port(), + } + + p := ping.New(ctx, t.Logf, &mockListenPacketer{conn4: conn4, conn6: conn6}) + + done := func() { + if err := p.Close(); err != nil { + t.Errorf("error on close: %v", err) + } + } + + return p, done +} + +type pingResponder struct { + net.PacketConn + running atomic.Bool + responseCount int +} + +func (p *pingResponder) start() { + buf := make([]byte, 1500) + for p.running.Load() { + n, addr, err := p.PacketConn.ReadFrom(buf) + if err != nil { + return + } + + m, err := icmp.ParseMessage(1, buf[:n]) + if err != nil { + panic("got a non-ICMP message:" + fmt.Sprintf("%x", m)) + } + + r := icmp.Message{ + Type: ipv4.ICMPTypeEchoReply, + Code: m.Code, + Body: m.Body, + } + + b, err := r.Marshal(nil) + if err != nil { + panic(err) + } + + if _, err := p.PacketConn.WriteTo(b, addr); err != nil { + panic(err) + } + p.responseCount++ + } +} + +func (p *pingResponder) stop() { + p.running.Store(false) + p.Close() +} + +func newPingResponder(t *testing.T) *pingResponder { + t.Helper() + // global binds should be both IPv4 and IPv6 (if our test platforms don't, + // we might need to bind two sockets instead) + conn, err := net.ListenPacket("udp", ":") + if err != nil { + t.Fatal(err) + } + pr := &pingResponder{PacketConn: conn} + pr.running.Store(true) + go pr.start() + t.Cleanup(pr.stop) + return pr +} + +func TestAddrForSendLockedForWireGuardOnly(t *testing.T) { + testTime := mono.Now() + + type endpointDetails struct { + addrPort netip.AddrPort + latency time.Duration + } + + wgTests := []struct { + name string + noV4 bool + noV6 bool + sendWGPing bool + ep []endpointDetails + want netip.AddrPort + }{ + { + name: "choose lowest latency for useable IPv4 and IPv6", + sendWGPing: true, + ep: []endpointDetails{ + { + addrPort: netip.MustParseAddrPort("1.1.1.1:111"), + latency: 100 * time.Millisecond, + }, + { + addrPort: netip.MustParseAddrPort("[2345:0425:2CA1:0000:0000:0567:5673:23b5]:222"), + latency: 10 * time.Millisecond, + }, + }, + want: netip.MustParseAddrPort("[2345:0425:2CA1:0000:0000:0567:5673:23b5]:222"), + }, + { + name: "choose IPv4 when IPv6 is not useable", + sendWGPing: false, + noV6: true, + ep: []endpointDetails{ + { + addrPort: netip.MustParseAddrPort("1.1.1.1:111"), + latency: 100 * time.Millisecond, + }, + { + addrPort: netip.MustParseAddrPort("[1::1]:567"), + }, + }, + want: netip.MustParseAddrPort("1.1.1.1:111"), + }, + { + name: "choose IPv6 when IPv4 is not useable", + sendWGPing: false, + noV4: true, + ep: []endpointDetails{ + { + addrPort: netip.MustParseAddrPort("1.1.1.1:111"), + }, + { + addrPort: netip.MustParseAddrPort("[1::1]:567"), + latency: 100 * time.Millisecond, + }, + }, + want: netip.MustParseAddrPort("[1::1]:567"), + }, + { + name: "choose IPv6 address when latency is the same for v4 and v6", + sendWGPing: true, + ep: []endpointDetails{ + { + addrPort: netip.MustParseAddrPort("1.1.1.1:111"), + latency: 100 * time.Millisecond, + }, + { + addrPort: netip.MustParseAddrPort("[1::1]:567"), + latency: 100 * time.Millisecond, + }, + }, + want: netip.MustParseAddrPort("[1::1]:567"), + }, + } + + for _, test := range wgTests { + endpoint := &endpoint{ + isWireguardOnly: true, + endpointState: map[netip.AddrPort]*endpointState{}, + c: &Conn{ + noV4: atomic.Bool{}, + noV6: atomic.Bool{}, + }, + } + endpoint.c.noV4.Store(test.noV4) + endpoint.c.noV6.Store(test.noV6) + + for _, epd := range test.ep { + endpoint.endpointState[epd.addrPort] = &endpointState{} + } + + udpAddr, _, shouldPing := endpoint.addrForSendLocked(testTime) + if !udpAddr.IsValid() { + t.Error("udpAddr returned is not valid") + } + if shouldPing != test.sendWGPing { + t.Errorf("addrForSendLocked did not indiciate correct ping state; got %v, want %v", shouldPing, test.sendWGPing) + } + + for _, epd := range test.ep { + state, ok := endpoint.endpointState[epd.addrPort] + if !ok { + t.Errorf("addr does not exist in endpoint state map") + } + + latency, ok := state.latencyLocked() + if ok { + t.Errorf("latency was set for %v: %v", epd.addrPort, latency) + } + state.recentPongs = append(state.recentPongs, pongReply{ + latency: epd.latency, + }) + state.recentPong = 0 + } + + udpAddr, _, shouldPing = endpoint.addrForSendLocked(testTime.Add(2 * time.Minute)) + if udpAddr != test.want { + t.Errorf("udpAddr returned is not expected: got %v, want %v", udpAddr, test.want) + } + if shouldPing { + t.Error("addrForSendLocked should not indicate ping is required") + } + if endpoint.bestAddr.AddrPort != test.want { + t.Errorf("bestAddr.AddrPort is not as expected: got %v, want %v", endpoint.bestAddr.AddrPort, test.want) + } + } +}