diff --git a/cmd/tailscaled/depaware.txt b/cmd/tailscaled/depaware.txt index b51b3a3de..16c8c6bba 100644 --- a/cmd/tailscaled/depaware.txt +++ b/cmd/tailscaled/depaware.txt @@ -245,7 +245,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de 💣 tailscale.com/net/netstat from tailscale.com/ipn/ipnauth+ tailscale.com/net/netutil from tailscale.com/ipn/ipnlocal+ tailscale.com/net/packet from tailscale.com/net/tstun+ - tailscale.com/net/ping from tailscale.com/net/netcheck + tailscale.com/net/ping from tailscale.com/net/netcheck+ tailscale.com/net/portmapper from tailscale.com/net/netcheck+ tailscale.com/net/proxymux from tailscale.com/cmd/tailscaled tailscale.com/net/routetable from tailscale.com/doctor/routetable diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 95af25891..3223441c4 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -50,6 +50,7 @@ "tailscale.com/net/netmon" "tailscale.com/net/netns" "tailscale.com/net/packet" + "tailscale.com/net/ping" "tailscale.com/net/portmapper" "tailscale.com/net/sockstats" "tailscale.com/net/stun" @@ -59,6 +60,7 @@ "tailscale.com/tstime" "tailscale.com/tstime/mono" "tailscale.com/types/key" + "tailscale.com/types/lazy" "tailscale.com/types/logger" "tailscale.com/types/netmap" "tailscale.com/types/nettype" @@ -209,11 +211,16 @@ func (m *peerMap) upsertEndpoint(ep *endpoint, oldDiscoKey key.DiscoPublic) { if epDisco == nil || oldDiscoKey != epDisco.key { delete(m.nodesOfDisco[oldDiscoKey], ep.publicKey) } - if epDisco == nil { - // If the peer does not support Disco, but it does have an endpoint address, - // attempt to use that (e.g. WireGuardOnly peers). - if ep.bestAddr.AddrPort.IsValid() { - m.setNodeKeyForIPPort(ep.bestAddr.AddrPort, ep.publicKey) + if ep.isWireguardOnly { + // If the peer is a WireGuard only peer, add all of its endpoints. + + // TODO(raggi,catzkorn): this could mean that if a "isWireguardOnly" + // peer has, say, 192.168.0.2 and so does a tailscale peer, the + // wireguard one will win. That may not be the outcome that we want - + // perhaps we should prefer bestAddr.AddrPort if it is set? + // see tailscale/tailscale#7994 + for ipp := range ep.endpointState { + m.setNodeKeyForIPPort(ipp, ep.publicKey) } return @@ -473,6 +480,9 @@ type Conn struct { // peerLastDerp tracks which DERP node we last used to speak with a // peer. It's only used to quiet logging, so we only log on change. peerLastDerp map[key.NodePublic]int + + // wgPinger is the WireGuard only pinger used for latency measurements. + wgPinger lazy.SyncValue[*ping.Pinger] } // SetDebugLoggingEnabled controls whether spammy debug logging is enabled. @@ -2766,6 +2776,7 @@ func (c *Conn) SetNetworkMap(nm *netmap.NetworkMap) { sentPing: map[stun.TxID]sentPing{}, endpointState: map[netip.AddrPort]*endpointState{}, heartbeatDisabled: heartbeatDisabled, + isWireguardOnly: n.IsWireGuardOnly, } if len(n.Addresses) > 0 { ep.nodeAddr = n.Addresses[0].Addr() @@ -3143,6 +3154,11 @@ func (c *Conn) Close() error { for c.goroutinesRunningLocked() { c.muCond.Wait() } + + if pinger := c.getPinger(); pinger != nil { + pinger.Close() + } + return nil } @@ -4084,9 +4100,14 @@ type endpointDisco struct { short string // ShortString of discoKey. } -// endpoint is a wireguard/conn.Endpoint that picks the best -// available path to communicate with a peer, based on network -// conditions and what the peer supports. +// endpoint is a wireguard/conn.Endpoint. In wireguard-go and kernel WireGuard +// there is only one endpoint for a peer, but in Tailscale we distribute a +// number of possible endpoints for a peer which would include the all the +// likely addresses at which a peer may be reachable. This endpoint type holds +// the information required that when WiregGuard-Go wants to send to a +// particular peer (essentally represented by this endpoint type), the send +// function can use the currnetly best known Tailscale endpoint to send packets +// to the peer. type endpoint struct { // atomically accessed; declared first for alignment reasons lastRecv mono.Time @@ -4108,7 +4129,7 @@ type endpoint struct { heartBeatTimer *time.Timer // nil when idle lastSend mono.Time // last time there was outgoing packets sent to this peer (from wireguard-go) - lastFullPing mono.Time // last time we pinged all endpoints + lastFullPing mono.Time // last time we pinged all disco endpoints derpAddr netip.AddrPort // fallback/bootstrap path, if non-zero (non-zero for well-behaved clients) bestAddr addrLatency // best non-DERP path; zero if none @@ -4126,7 +4147,8 @@ type endpoint struct { heartbeatDisabled bool pathFinderRunning bool - expired bool // whether the node has expired + expired bool // whether the node has expired + isWireguardOnly bool // whether the endpoint is WireGuard only } type pendingCLIPing struct { @@ -4238,6 +4260,15 @@ func (st *endpointState) shouldDeleteLocked() bool { } } +// latencyLocked returns the most recent latency measurement, if any. +// endpoint.mu must be held. +func (st *endpointState) latencyLocked() (lat time.Duration, ok bool) { + if len(st.recentPongs) == 0 { + return 0, false + } + return st.recentPongs[st.recentPong].latency, true +} + func (de *endpoint) deleteEndpointLocked(why string, ep netip.AddrPort) { de.debugUpdates.Add(EndpointChange{ When: time.Now(), @@ -4321,17 +4352,87 @@ func (de *endpoint) DstToBytes() []byte { return packIPPort(de.fakeWGAddr) } // addrForSendLocked returns the address(es) that should be used for // sending the next packet. Zero, one, or both of UDP address and DERP -// addr may be non-zero. +// addr may be non-zero. If the endpoint is WireGuard only and does not have +// latency information, a bool is returned to indiciate that the +// WireGuard latency discovery pings should be sent. // // de.mu must be held. -func (de *endpoint) addrForSendLocked(now mono.Time) (udpAddr, derpAddr netip.AddrPort) { +func (de *endpoint) addrForSendLocked(now mono.Time) (udpAddr, derpAddr netip.AddrPort, sendWGPing bool) { udpAddr = de.bestAddr.AddrPort - if !udpAddr.IsValid() || now.After(de.trustBestAddrUntil) { - // We had a bestAddr but it expired so send both to it - // and DERP. - derpAddr = de.derpAddr + + if udpAddr.IsValid() && !now.After(de.trustBestAddrUntil) { + return udpAddr, netip.AddrPort{}, false } - return + + if de.isWireguardOnly { + // If the endpoint is wireguard-only, we don't have a DERP + // address to send to, so we have to send to the UDP address. + udpAddr, shouldPing := de.addrForWireGuardSendLocked(now) + return udpAddr, netip.AddrPort{}, shouldPing + } + + // We had a bestAddr but it expired so send both to it + // and DERP. + return udpAddr, de.derpAddr, false +} + +// addrForWireGuardSendLocked returns the address that should be used for +// sending the next packet. If a packet has never or not recently been sent to +// the endpoint, then a randomly selected address for the endpoint is returned, +// as well as a bool indiciating that WireGuard discovery pings should be started. +// If the addresses have latency information available, then the address with the +// best latency is used. +// +// de.mu must be held. +func (de *endpoint) addrForWireGuardSendLocked(now mono.Time) (udpAddr netip.AddrPort, shouldPing bool) { + // lowestLatency is a high duration initially, so we + // can be sure we're going to have a duration lower than this + // for the first latency retrieved. + lowestLatency := time.Hour + for ipp, state := range de.endpointState { + if latency, ok := state.latencyLocked(); ok { + if latency < lowestLatency || latency == lowestLatency && ipp.Addr().Is6() { + // If we have the same latency,IPv6 is prioritized. + // TODO(catzkorn): Consider a small increase in latency to use + // IPv6 in comparison to IPv4, when possible. + lowestLatency = latency + udpAddr = ipp + } + } + } + + if udpAddr.IsValid() { + // Set trustBestAddrUntil to an hour, so we will + // continue to use this address for a long period of time. + de.bestAddr.AddrPort = udpAddr + de.trustBestAddrUntil = now.Add(1 * time.Hour) + return udpAddr, false + } + + candidates := make([]netip.AddrPort, 0, len(de.endpointState)) + for ipp := range de.endpointState { + if ipp.Addr().Is4() && de.c.noV4.Load() { + continue + } + if ipp.Addr().Is6() && de.c.noV6.Load() { + continue + } + candidates = append(candidates, ipp) + } + // Randomly select an address to use until we retrieve latency information + // and give it a short trustBestAddrUntil time so we avoid flapping between + // addresses while waiting on latency information to be populated. + udpAddr = candidates[rand.Intn(len(candidates))] + de.bestAddr.AddrPort = udpAddr + if len(candidates) == 1 { + // if we only have one address that we can send data too, + // we should trust it for a longer period of time. + de.trustBestAddrUntil = now.Add(1 * time.Hour) + } else { + de.trustBestAddrUntil = now.Add(15 * time.Second) + } + + return udpAddr, len(candidates) > 1 } // heartbeat is called every heartbeatInterval to keep the best UDP path alive, @@ -4359,14 +4460,14 @@ func (de *endpoint) heartbeat() { } now := mono.Now() - udpAddr, _ := de.addrForSendLocked(now) + udpAddr, _, _ := de.addrForSendLocked(now) if udpAddr.IsValid() { // We have a preferred path. Ping that every 2 seconds. - de.startPingLocked(udpAddr, now, pingHeartbeat) + de.startDiscoPingLocked(udpAddr, now, pingHeartbeat) } if de.wantFullPingLocked(now) { - de.sendPingsLocked(now, true) + de.sendDiscoPingsLocked(now, true) } de.heartBeatTimer = time.AfterFunc(heartbeatInterval, de.heartbeat) @@ -4417,19 +4518,19 @@ func (de *endpoint) cliPing(res *ipnstate.PingResult, cb func(*ipnstate.PingResu de.pendingCLIPings = append(de.pendingCLIPings, pendingCLIPing{res, cb}) now := mono.Now() - udpAddr, derpAddr := de.addrForSendLocked(now) + udpAddr, derpAddr, _ := de.addrForSendLocked(now) if derpAddr.IsValid() { - de.startPingLocked(derpAddr, now, pingCLI) + de.startDiscoPingLocked(derpAddr, now, pingCLI) } if udpAddr.IsValid() && now.Before(de.trustBestAddrUntil) { // Already have an active session, so just ping the address we're using. // Otherwise "tailscale ping" results to a node on the local network // can look like they're bouncing between, say 10.0.0.0/9 and the peer's // IPv6 address, both 1ms away, and it's random who replies first. - de.startPingLocked(udpAddr, now, pingCLI) + de.startDiscoPingLocked(udpAddr, now, pingCLI) } else { for ep := range de.endpointState { - de.startPingLocked(ep, now, pingCLI) + de.startDiscoPingLocked(ep, now, pingCLI) } } de.noteActiveLocked() @@ -4459,9 +4560,14 @@ func (de *endpoint) send(buffs [][]byte) error { } now := mono.Now() - udpAddr, derpAddr := de.addrForSendLocked(now) - if !udpAddr.IsValid() || now.After(de.trustBestAddrUntil) { - de.sendPingsLocked(now, true) + udpAddr, derpAddr, startWGPing := de.addrForSendLocked(now) + + if de.isWireguardOnly { + if startWGPing { + de.sendWireGuardOnlyPingsLocked(now) + } + } else if !udpAddr.IsValid() || now.After(de.trustBestAddrUntil) { + de.sendDiscoPingsLocked(now, true) } de.noteActiveLocked() de.mu.Unlock() @@ -4499,7 +4605,7 @@ func (de *endpoint) send(buffs [][]byte) error { return err } -func (de *endpoint) pingTimeout(txid stun.TxID) { +func (de *endpoint) discoPingTimeout(txid stun.TxID) { de.mu.Lock() defer de.mu.Unlock() sp, ok := de.sentPing[txid] @@ -4509,20 +4615,20 @@ func (de *endpoint) pingTimeout(txid stun.TxID) { if debugDisco() || !de.bestAddr.IsValid() || mono.Now().After(de.trustBestAddrUntil) { de.c.dlogf("[v1] magicsock: disco: timeout waiting for pong %x from %v (%v, %v)", txid[:6], sp.to, de.publicKey.ShortString(), de.discoShort()) } - de.removeSentPingLocked(txid, sp) + de.removeSentDiscoPingLocked(txid, sp) } -// forgetPing is called by a timer when a ping either fails to send or +// forgetDiscoPing is called by a timer when a ping either fails to send or // has taken too long to get a pong reply. -func (de *endpoint) forgetPing(txid stun.TxID) { +func (de *endpoint) forgetDiscoPing(txid stun.TxID) { de.mu.Lock() defer de.mu.Unlock() if sp, ok := de.sentPing[txid]; ok { - de.removeSentPingLocked(txid, sp) + de.removeSentDiscoPingLocked(txid, sp) } } -func (de *endpoint) removeSentPingLocked(txid stun.TxID, sp sentPing) { +func (de *endpoint) removeSentDiscoPingLocked(txid stun.TxID, sp sentPing) { // Stop the timer for the case where sendPing failed to write to UDP. // In the case of a timer already having fired, this is a no-op: sp.timer.Stop() @@ -4542,7 +4648,7 @@ func (de *endpoint) sendDiscoPing(ep netip.AddrPort, discoKey key.DiscoPublic, t NodeKey: de.c.publicKeyAtomic.Load(), }, logLevel) if !sent { - de.forgetPing(txid) + de.forgetDiscoPing(txid) } } @@ -4564,7 +4670,7 @@ func (de *endpoint) sendDiscoPing(ep netip.AddrPort, discoKey key.DiscoPublic, t pingCLI ) -func (de *endpoint) startPingLocked(ep netip.AddrPort, now mono.Time, purpose discoPingPurpose) { +func (de *endpoint) startDiscoPingLocked(ep netip.AddrPort, now mono.Time, purpose discoPingPurpose) { if runtime.GOOS == "js" { return } @@ -4587,7 +4693,7 @@ func (de *endpoint) startPingLocked(ep netip.AddrPort, now mono.Time, purpose di de.sentPing[txid] = sentPing{ to: ep, at: now, - timer: time.AfterFunc(pingTimeoutDuration, func() { de.pingTimeout(txid) }), + timer: time.AfterFunc(pingTimeoutDuration, func() { de.discoPingTimeout(txid) }), purpose: purpose, } logLevel := discoLog @@ -4597,7 +4703,7 @@ func (de *endpoint) startPingLocked(ep netip.AddrPort, now mono.Time, purpose di go de.sendDiscoPing(ep, epDisco.key, txid, logLevel) } -func (de *endpoint) sendPingsLocked(now mono.Time, sendCallMeMaybe bool) { +func (de *endpoint) sendDiscoPingsLocked(now mono.Time, sendCallMeMaybe bool) { de.lastFullPing = now var sentAny bool for ep, st := range de.endpointState { @@ -4619,7 +4725,7 @@ func (de *endpoint) sendPingsLocked(now mono.Time, sendCallMeMaybe bool) { de.c.dlogf("[v1] magicsock: disco: send, starting discovery for %v (%v)", de.publicKey.ShortString(), de.discoShort()) } - de.startPingLocked(ep, now, pingDiscovery) + de.startDiscoPingLocked(ep, now, pingDiscovery) } derpAddr := de.derpAddr if sentAny && sendCallMeMaybe && derpAddr.IsValid() { @@ -4632,9 +4738,99 @@ func (de *endpoint) sendPingsLocked(now mono.Time, sendCallMeMaybe bool) { } } +// sendWireGuardOnlyPingsLocked evaluates all available addresses for +// a WireGuard only endpoint and initates an ICMP ping for useable +// addresses. +func (de *endpoint) sendWireGuardOnlyPingsLocked(now mono.Time) { + if runtime.GOOS == "js" { + return + } + + // Normally the we only send pings at a low rate as the decision to start + // sending a ping sets bestAddrAtUntil with a reasonable time to keep trying + // that address, however, if that code changed we may want to be sure that + // we don't ever send excessive pings to avoid impact to the client/user. + if !now.After(de.lastFullPing.Add(10 * time.Second)) { + return + } + de.lastFullPing = now + + for ipp := range de.endpointState { + if ipp.Addr().Is4() && de.c.noV4.Load() { + continue + } + if ipp.Addr().Is6() && de.c.noV6.Load() { + continue + } + + go de.sendWireGuardOnlyPing(ipp, now) + } +} + +// getPinger lazily instantiates a pinger and returns it, if it was +// already instantiated it returns the existing one. +func (c *Conn) getPinger() *ping.Pinger { + return c.wgPinger.Get(func() *ping.Pinger { + return ping.New(c.connCtx, c.dlogf, netns.Listener(c.logf, c.netMon)) + }) +} + +// sendWireGuardOnlyPing sends a ICMP ping to a WireGuard only address to +// discover the latency. +func (de *endpoint) sendWireGuardOnlyPing(ipp netip.AddrPort, now mono.Time) { + ctx, cancel := context.WithTimeout(de.c.connCtx, 5*time.Second) + defer cancel() + + de.setLastPing(ipp, now) + + addr := &net.IPAddr{ + IP: net.IP(ipp.Addr().AsSlice()), + Zone: ipp.Addr().Zone(), + } + + p := de.c.getPinger() + if p == nil { + de.c.logf("[v2] magicsock: sendWireGuardOnlyPingLocked: pinger is nil") + return + } + + latency, err := p.Send(ctx, addr, nil) + if err != nil { + de.c.logf("[v2] magicsock: sendWireGuardOnlyPingLocked: %s", err) + return + } + + de.mu.Lock() + defer de.mu.Unlock() + + state, ok := de.endpointState[ipp] + if !ok { + return + } + state.addPongReplyLocked(pongReply{ + latency: latency, + pongAt: now, + from: ipp, + pongSrc: netip.AddrPort{}, // We don't know this. + }) +} + +// setLastPing sets lastPing on the endpointState to now. +func (de *endpoint) setLastPing(ipp netip.AddrPort, now mono.Time) { + de.mu.Lock() + defer de.mu.Unlock() + state, ok := de.endpointState[ipp] + if !ok { + return + } + state.lastPing = now +} + +// updateFromNode updates the endpoint based on a tailcfg.Node from a NetMap +// update. func (de *endpoint) updateFromNode(n *tailcfg.Node, heartbeatDisabled bool) { if n == nil { - panic("nil node when updating disco ep") + panic("nil node when updating endpoint") } de.mu.Lock() defer de.mu.Unlock() @@ -4642,22 +4838,6 @@ func (de *endpoint) updateFromNode(n *tailcfg.Node, heartbeatDisabled bool) { de.heartbeatDisabled = heartbeatDisabled de.expired = n.Expired - // TODO(#7826): add support for more than one endpoint for pure WireGuard - // peers, and/or support for probing "bestness" for endpoints. - if n.IsWireGuardOnly { - for _, ep := range n.Endpoints { - ipp, err := netip.ParseAddrPort(ep) - if err != nil { - de.c.logf("magicsock: invalid endpoint: %s %s", ep, err) - continue - } - de.bestAddr = addrLatency{ - AddrPort: ipp, - } - break - } - } - epDisco := de.disco.Load() var discoKey key.DiscoPublic if epDisco != nil { @@ -4810,7 +4990,7 @@ func (de *endpoint) handlePongConnLocked(m *disco.Pong, di *discoInfo, src netip return false } knownTxID = true // for naked returns below - de.removeSentPingLocked(m.TxID, sp) + de.removeSentDiscoPingLocked(m.TxID, sp) now := mono.Now() latency := now.Sub(sp.at) @@ -5026,7 +5206,7 @@ func (de *endpoint) handleCallMeMaybe(m *disco.CallMeMaybe) { for _, st := range de.endpointState { st.lastPing = 0 } - de.sendPingsLocked(mono.Now(), false) + de.sendDiscoPingsLocked(mono.Now(), false) } func (de *endpoint) populatePeerStatus(ps *ipnstate.PeerStatus) { @@ -5043,7 +5223,7 @@ func (de *endpoint) populatePeerStatus(ps *ipnstate.PeerStatus) { ps.LastWrite = de.lastSend.WallTime() ps.Active = now.Sub(de.lastSend) < sessionActiveTimeout - if udpAddr, derpAddr := de.addrForSendLocked(now); udpAddr.IsValid() && !derpAddr.IsValid() { + if udpAddr, derpAddr, _ := de.addrForSendLocked(now); udpAddr.IsValid() && !derpAddr.IsValid() { ps.CurAddr = udpAddr.String() } } @@ -5086,7 +5266,7 @@ func (de *endpoint) resetLocked() { es.lastPing = 0 } for txid, sp := range de.sentPing { - de.removeSentPingLocked(txid, sp) + de.removeSentDiscoPingLocked(txid, sp) } } diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index 9115a3eea..dee6c8d47 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -23,6 +23,7 @@ "strconv" "strings" "sync" + "sync/atomic" "testing" "time" "unsafe" @@ -33,6 +34,8 @@ "go4.org/mem" "golang.org/x/exp/maps" "golang.org/x/exp/slices" + "golang.org/x/net/icmp" + "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" "tailscale.com/cmd/testwrapper/flakytest" "tailscale.com/derp" @@ -42,11 +45,13 @@ "tailscale.com/net/connstats" "tailscale.com/net/netaddr" "tailscale.com/net/packet" + "tailscale.com/net/ping" "tailscale.com/net/stun/stuntest" "tailscale.com/net/tstun" "tailscale.com/tailcfg" "tailscale.com/tstest" "tailscale.com/tstest/natlab" + "tailscale.com/tstime/mono" "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/netlogtype" @@ -2117,9 +2122,8 @@ func Test_batchingUDPConn_coalesceMessages(t *testing.T) { } // newWireguard starts up a new wireguard-go device attached to a test tun, and -// returns the device, tun and netpoint address. To add peers call device.IpcSet -// with UAPI instructions. -func newWireguard(t *testing.T, uapi string, aips []netip.Prefix) (*device.Device, *tuntest.ChannelTUN, netip.AddrPort) { +// returns the device, tun and endpoint port. To add peers call device.IpcSet with UAPI instructions. +func newWireguard(t *testing.T, uapi string, aips []netip.Prefix) (*device.Device, *tuntest.ChannelTUN, uint16) { wgtun := tuntest.NewChannelTUN() wglogf := func(f string, args ...any) { t.Logf("wg-go: "+f, args...) @@ -2138,8 +2142,7 @@ func newWireguard(t *testing.T, uapi string, aips []netip.Prefix) (*device.Devic t.Fatal(err) } - var wgEp netip.AddrPort - + var port uint16 s, err := wgdev.IpcGet() if err != nil { t.Fatal(err) @@ -2151,17 +2154,16 @@ func newWireguard(t *testing.T, uapi string, aips []netip.Prefix) (*device.Devic } k, v, _ := strings.Cut(line, "=") if k == "listen_port" { - wgEp = netip.MustParseAddrPort("127.0.0.1:" + v) + p, err := strconv.ParseUint(v, 10, 16) + if err != nil { + panic(err) + } + port = uint16(p) break } } - if !wgEp.IsValid() { - t.Fatalf("failed to get endpoint out of wg-go") - } - t.Logf("wg-go endpoint: %s", wgEp) - - return wgdev, wgtun, wgEp + return wgdev, wgtun, port } func TestIsWireGuardOnlyPeer(t *testing.T) { @@ -2176,8 +2178,9 @@ func TestIsWireGuardOnlyPeer(t *testing.T) { uapi := fmt.Sprintf("private_key=%s\npublic_key=%s\nallowed_ip=%s\n\n", wgkey.UntypedHexString(), tskey.Public().UntypedHexString(), tsaip.String()) - wgdev, wgtun, wgEp := newWireguard(t, uapi, []netip.Prefix{wgaip}) + wgdev, wgtun, port := newWireguard(t, uapi, []netip.Prefix{wgaip}) defer wgdev.Close() + wgEp := netip.AddrPortFrom(netip.MustParseAddr("127.0.0.1"), port) m := newMagicStackWithKey(t, t.Logf, localhostListener{}, derpMap, tskey) defer m.Close() @@ -2233,8 +2236,9 @@ func TestIsWireGuardOnlyPeerWithMasquerade(t *testing.T) { uapi := fmt.Sprintf("private_key=%s\npublic_key=%s\nallowed_ip=%s\n\n", wgkey.UntypedHexString(), tskey.Public().UntypedHexString(), masqip.String()) - wgdev, wgtun, wgEp := newWireguard(t, uapi, []netip.Prefix{wgaip}) + wgdev, wgtun, port := newWireguard(t, uapi, []netip.Prefix{wgaip}) defer wgdev.Close() + wgEp := netip.AddrPortFrom(netip.MustParseAddr("127.0.0.1"), port) m := newMagicStackWithKey(t, t.Logf, localhostListener{}, derpMap, tskey) defer m.Close() @@ -2397,3 +2401,459 @@ func TestEndpointTracker(t *testing.T) { } } } + +// applyNetworkMap is a test helper that sets the network map and +// configures WG. +func applyNetworkMap(t *testing.T, m *magicStack, nm *netmap.NetworkMap) { + t.Helper() + m.conn.SetNetworkMap(nm) + // Make sure we can't use v6 to avoid test failures. + m.conn.noV6.Store(true) + + // Turn the network map into a wireguard config (for the tailscale internal wireguard device). + cfg, err := nmcfg.WGCfg(nm, t.Logf, netmap.AllowSingleHosts|netmap.AllowSubnetRoutes, "") + if err != nil { + t.Fatal(err) + } + // Apply the wireguard config to the tailscale internal wireguard device. + if err := m.Reconfig(cfg); err != nil { + t.Fatal(err) + } +} + +func TestIsWireGuardOnlyPickEndpointByPing(t *testing.T) { + clock := &tstest.Clock{} + derpMap, cleanup := runDERPAndStun(t, t.Logf, localhostListener{}, netaddr.IPv4(127, 0, 0, 1)) + defer cleanup() + + // Create a TS client. + tskey := key.NewNode() + tsaip := netip.MustParsePrefix("100.111.222.111/32") + + // Create a WireGuard only client. + wgkey := key.NewNode() + wgaip := netip.MustParsePrefix("100.222.111.222/32") + + uapi := fmt.Sprintf("private_key=%s\npublic_key=%s\nallowed_ip=%s\n\n", + wgkey.UntypedHexString(), tskey.Public().UntypedHexString(), tsaip.String()) + + wgdev, wgtun, port := newWireguard(t, uapi, []netip.Prefix{wgaip}) + defer wgdev.Close() + wgEp := netip.AddrPortFrom(netip.MustParseAddr("127.0.0.1"), port) + wgEp2 := netip.AddrPortFrom(netip.MustParseAddr("127.0.0.2"), port) + + m := newMagicStackWithKey(t, t.Logf, localhostListener{}, derpMap, tskey) + defer m.Close() + + pr := newPingResponder(t) + // Get a destination address which includes a port, so that UDP packets flow + // to the correct place, the mockPinger will use this to direct port-less + // pings to this place. + pingDest := pr.LocalAddr() + + // Create and start the pinger that is used for the + // wireguard only endpoint pings + p, closeP := mockPinger(t, clock, pingDest) + defer closeP() + m.conn.wgPinger.Set(p) + + // Create an IPv6 endpoint which should not receive any traffic. + v6, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.ParseIP("::"), Port: 0}) + if err != nil { + t.Fatal(err) + } + badEpRecv := make(chan []byte) + go func() { + defer v6.Close() + for { + b := make([]byte, 1500) + n, _, err := v6.ReadFrom(b) + if err != nil { + close(badEpRecv) + return + } + badEpRecv <- b[:n] + } + }() + wgEpV6 := netip.MustParseAddrPort(v6.LocalAddr().String()) + + nm := &netmap.NetworkMap{ + Name: "ts", + PrivateKey: m.privateKey, + NodeKey: m.privateKey.Public(), + Addresses: []netip.Prefix{tsaip}, + Peers: []*tailcfg.Node{ + { + Key: wgkey.Public(), + Endpoints: []string{wgEp.String(), wgEp2.String(), wgEpV6.String()}, + IsWireGuardOnly: true, + Addresses: []netip.Prefix{wgaip}, + AllowedIPs: []netip.Prefix{wgaip}, + }, + }, + } + + applyNetworkMap(t, m, nm) + + buf := tuntest.Ping(wgaip.Addr(), tsaip.Addr()) + m.tun.Outbound <- buf + + select { + case p := <-wgtun.Inbound: + if !bytes.Equal(p, buf) { + t.Errorf("got unexpected packet: %x", p) + } + case <-badEpRecv: + t.Fatal("got packet on bad endpoint") + case <-time.After(5 * time.Second): + t.Fatal("no packet after 1s") + } + + pi, ok := m.conn.peerMap.byNodeKey[wgkey.Public()] + if !ok { + t.Fatal("wgkey doesn't exist in peer map") + } + + // Check that we got a valid address set on the first send - this + // will be randomly selected, but because we have noV6 set to true, + // it will be the IPv4 address. + if !pi.ep.bestAddr.Addr().IsValid() { + t.Fatal("bestaddr was nil") + } + + if pi.ep.trustBestAddrUntil.Before(mono.Now().Add(14 * time.Second)) { + t.Errorf("trustBestAddrUntil time wasn't set to 15 seconds in the future: got %v", pi.ep.trustBestAddrUntil) + } + + for ipp, state := range pi.ep.endpointState { + if ipp == wgEp { + if len(state.recentPongs) != 1 { + t.Errorf("IPv4 address did not have a recentPong entry: got %v, want %v", len(state.recentPongs), 1) + } + // Set the latency extremely low so we choose this endpoint during the next + // addrForSendLocked call. + state.recentPongs[state.recentPong].latency = time.Nanosecond + } + + if ipp == wgEp2 { + if len(state.recentPongs) != 1 { + t.Errorf("IPv4 address did not have a recentPong entry: got %v, want %v", len(state.recentPongs), 1) + } + // Set the latency extremely high so we dont choose endpoint during the next + // addrForSendLocked call. + state.recentPongs[state.recentPong].latency = time.Second + } + + if ipp == wgEpV6 && len(state.recentPongs) != 0 { + t.Fatal("IPv6 should not have recentPong: IPv6 is not useable") + } + } + + // Set trustBestAddrUnitl to now, so addrForSendLocked goes through the + // latency selection flow. + pi.ep.trustBestAddrUntil = mono.Now().Add(-time.Second) + + buf = tuntest.Ping(wgaip.Addr(), tsaip.Addr()) + m.tun.Outbound <- buf + + select { + case p := <-wgtun.Inbound: + if !bytes.Equal(p, buf) { + t.Errorf("got unexpected packet: %x", p) + } + case <-badEpRecv: + t.Fatal("got packet on bad endpoint") + case <-time.After(5 * time.Second): + t.Fatal("no packet after 1s") + } + + // Check that we have responded to a WireGuard only ping twice. + if pr.responseCount != 2 { + t.Fatal("pingresponder response count was not 2", pr.responseCount) + } + + pi, ok = m.conn.peerMap.byNodeKey[wgkey.Public()] + if !ok { + t.Fatal("wgkey doesn't exist in peer map") + } + + if !pi.ep.bestAddr.Addr().IsValid() { + t.Error("no bestAddr address was set") + } + + if pi.ep.bestAddr.Addr() != wgEp.Addr() { + t.Errorf("bestAddr was not set to the expected IPv4 address: got %v, want %v", pi.ep.bestAddr.Addr().String(), wgEp.Addr()) + } + + if pi.ep.trustBestAddrUntil.IsZero() { + t.Fatal("trustBestAddrUntil was not set") + } + + if pi.ep.trustBestAddrUntil.Before(mono.Now().Add(55 * time.Minute)) { + // Set to 55 minutes incase of sloooow tests. + t.Errorf("trustBestAddrUntil time wasn't set to an hour in the future: got %v", pi.ep.trustBestAddrUntil) + } +} + +// udpingPacketConn will convert potentially ICMP destination addrs to UDP +// destination addrs in WriteTo so that a test that is intending to send ICMP +// traffic will instead send UDP traffic, without the higher level Pinger being +// aware of this difference. +type udpingPacketConn struct { + net.PacketConn + // destPort will be configured by the test to be the peer expected to respond to a ping. + destPort uint16 +} + +func (u *udpingPacketConn) WriteTo(body []byte, dest net.Addr) (int, error) { + switch d := dest.(type) { + case *net.IPAddr: + udpAddr := &net.UDPAddr{ + IP: d.IP, + Port: int(u.destPort), + Zone: d.Zone, + } + return u.PacketConn.WriteTo(body, udpAddr) + } + return 0, fmt.Errorf("unimplemented udpingPacketConn for %T", dest) +} + +type mockListenPacketer struct { + conn4 net.PacketConn + conn6 net.PacketConn +} + +func (mlp *mockListenPacketer) ListenPacket(ctx context.Context, typ string, addr string) (net.PacketConn, error) { + switch typ { + case "ip4:icmp": + return mlp.conn4, nil + case "ip6:icmp": + return mlp.conn6, nil + } + return nil, fmt.Errorf("unimplemented ListenPacketForTesting for %s", typ) +} + +func mockPinger(t *testing.T, clock *tstest.Clock, dest net.Addr) (*ping.Pinger, func()) { + ctx := context.Background() + + dIPP := netip.MustParseAddrPort(dest.String()) + // In tests, we use UDP so that we can test without being root; this + // doesn't matter because we mock out the ICMP reply below to be a real + // ICMP echo reply packet. + conn4, err := net.ListenPacket("udp4", "127.0.0.1:0") + if err != nil { + t.Fatalf("net.ListenPacket: %v", err) + } + conn6, err := net.ListenPacket("udp6", "[::]:0") + if err != nil { + t.Fatalf("net.ListenPacket: %v", err) + } + + conn4 = &udpingPacketConn{ + PacketConn: conn4, + destPort: dIPP.Port(), + } + + conn6 = &udpingPacketConn{ + PacketConn: conn6, + destPort: dIPP.Port(), + } + + p := ping.New(ctx, t.Logf, &mockListenPacketer{conn4: conn4, conn6: conn6}) + + done := func() { + if err := p.Close(); err != nil { + t.Errorf("error on close: %v", err) + } + } + + return p, done +} + +type pingResponder struct { + net.PacketConn + running atomic.Bool + responseCount int +} + +func (p *pingResponder) start() { + buf := make([]byte, 1500) + for p.running.Load() { + n, addr, err := p.PacketConn.ReadFrom(buf) + if err != nil { + return + } + + m, err := icmp.ParseMessage(1, buf[:n]) + if err != nil { + panic("got a non-ICMP message:" + fmt.Sprintf("%x", m)) + } + + r := icmp.Message{ + Type: ipv4.ICMPTypeEchoReply, + Code: m.Code, + Body: m.Body, + } + + b, err := r.Marshal(nil) + if err != nil { + panic(err) + } + + if _, err := p.PacketConn.WriteTo(b, addr); err != nil { + panic(err) + } + p.responseCount++ + } +} + +func (p *pingResponder) stop() { + p.running.Store(false) + p.Close() +} + +func newPingResponder(t *testing.T) *pingResponder { + t.Helper() + // global binds should be both IPv4 and IPv6 (if our test platforms don't, + // we might need to bind two sockets instead) + conn, err := net.ListenPacket("udp", ":") + if err != nil { + t.Fatal(err) + } + pr := &pingResponder{PacketConn: conn} + pr.running.Store(true) + go pr.start() + t.Cleanup(pr.stop) + return pr +} + +func TestAddrForSendLockedForWireGuardOnly(t *testing.T) { + testTime := mono.Now() + + type endpointDetails struct { + addrPort netip.AddrPort + latency time.Duration + } + + wgTests := []struct { + name string + noV4 bool + noV6 bool + sendWGPing bool + ep []endpointDetails + want netip.AddrPort + }{ + { + name: "choose lowest latency for useable IPv4 and IPv6", + sendWGPing: true, + ep: []endpointDetails{ + { + addrPort: netip.MustParseAddrPort("1.1.1.1:111"), + latency: 100 * time.Millisecond, + }, + { + addrPort: netip.MustParseAddrPort("[2345:0425:2CA1:0000:0000:0567:5673:23b5]:222"), + latency: 10 * time.Millisecond, + }, + }, + want: netip.MustParseAddrPort("[2345:0425:2CA1:0000:0000:0567:5673:23b5]:222"), + }, + { + name: "choose IPv4 when IPv6 is not useable", + sendWGPing: false, + noV6: true, + ep: []endpointDetails{ + { + addrPort: netip.MustParseAddrPort("1.1.1.1:111"), + latency: 100 * time.Millisecond, + }, + { + addrPort: netip.MustParseAddrPort("[1::1]:567"), + }, + }, + want: netip.MustParseAddrPort("1.1.1.1:111"), + }, + { + name: "choose IPv6 when IPv4 is not useable", + sendWGPing: false, + noV4: true, + ep: []endpointDetails{ + { + addrPort: netip.MustParseAddrPort("1.1.1.1:111"), + }, + { + addrPort: netip.MustParseAddrPort("[1::1]:567"), + latency: 100 * time.Millisecond, + }, + }, + want: netip.MustParseAddrPort("[1::1]:567"), + }, + { + name: "choose IPv6 address when latency is the same for v4 and v6", + sendWGPing: true, + ep: []endpointDetails{ + { + addrPort: netip.MustParseAddrPort("1.1.1.1:111"), + latency: 100 * time.Millisecond, + }, + { + addrPort: netip.MustParseAddrPort("[1::1]:567"), + latency: 100 * time.Millisecond, + }, + }, + want: netip.MustParseAddrPort("[1::1]:567"), + }, + } + + for _, test := range wgTests { + endpoint := &endpoint{ + isWireguardOnly: true, + endpointState: map[netip.AddrPort]*endpointState{}, + c: &Conn{ + noV4: atomic.Bool{}, + noV6: atomic.Bool{}, + }, + } + endpoint.c.noV4.Store(test.noV4) + endpoint.c.noV6.Store(test.noV6) + + for _, epd := range test.ep { + endpoint.endpointState[epd.addrPort] = &endpointState{} + } + + udpAddr, _, shouldPing := endpoint.addrForSendLocked(testTime) + if !udpAddr.IsValid() { + t.Error("udpAddr returned is not valid") + } + if shouldPing != test.sendWGPing { + t.Errorf("addrForSendLocked did not indiciate correct ping state; got %v, want %v", shouldPing, test.sendWGPing) + } + + for _, epd := range test.ep { + state, ok := endpoint.endpointState[epd.addrPort] + if !ok { + t.Errorf("addr does not exist in endpoint state map") + } + + latency, ok := state.latencyLocked() + if ok { + t.Errorf("latency was set for %v: %v", epd.addrPort, latency) + } + state.recentPongs = append(state.recentPongs, pongReply{ + latency: epd.latency, + }) + state.recentPong = 0 + } + + udpAddr, _, shouldPing = endpoint.addrForSendLocked(testTime.Add(2 * time.Minute)) + if udpAddr != test.want { + t.Errorf("udpAddr returned is not expected: got %v, want %v", udpAddr, test.want) + } + if shouldPing { + t.Error("addrForSendLocked should not indicate ping is required") + } + if endpoint.bestAddr.AddrPort != test.want { + t.Errorf("bestAddr.AddrPort is not as expected: got %v, want %v", endpoint.bestAddr.AddrPort, test.want) + } + } +}