Jordan Whited cf253057f7
wgengine/magicsock: always use Cryptokey Routing identification
We only set [epAddr]s in the [peerMap] when wireguard-go tells us who
they belong to. A node key can only have a single [epAddr] in the
[peerMap].

We also clear an [epAddr] when wireguard-go tells us our
mapping assumption between [epAddr] and peer is wrong (outdated).

Signed-off-by: Jordan Whited <jordan@tailscale.com>
2025-07-02 19:31:03 -07:00

212 lines
6.1 KiB
Go

// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package magicsock
import (
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/util/set"
)
// peerInfo is all the information magicsock tracks about a particular
// peer.
type peerInfo struct {
	ep *endpoint // always non-nil.
	// epAddr is the inverse of this peer's entry in peerMap.byEpAddr: the
	// single [epAddr] recorded for this peer. Keeping it here lets peerMap
	// remove the peer's byEpAddr entry directly, without iterating over every
	// epAddr known for any peer. peerMap.setNodeKeyForEpAddr resets it to the
	// zero value when the address is claimed by a different peer.
	epAddr epAddr
}
// newPeerInfo returns a fresh peerInfo for ep. No epAddr is associated with
// it yet.
func newPeerInfo(ep *endpoint) *peerInfo {
	info := new(peerInfo)
	info.ep = ep
	return info
}
// peerMap indexes peerInfos by node (WireGuard) key, by node ID, and by the
// [epAddr] a peer has been verified to send from. It also records, for each
// disco key, the set of node keys currently using that disco key.
//
// It doesn't do any locking; all access must be done with Conn.mu held.
type peerMap struct {
	// byNodeKey is the primary index. Peers referenced from the other
	// indexes are expected to also be present here.
	byNodeKey map[key.NodePublic]*peerInfo
	// byEpAddr maps an [epAddr] to the peer recorded as sending from it.
	// The design intends at most one entry per peer (see peerInfo.epAddr).
	byEpAddr map[epAddr]*peerInfo
	// byNodeID maps a node's tailcfg.NodeID to its peerInfo; it is expected
	// to stay the same size as byNodeKey (see nodeCount).
	byNodeID map[tailcfg.NodeID]*peerInfo
	// nodesOfDisco contains the set of nodes that are using a
	// DiscoKey. Usually those sets will be just one node.
	nodesOfDisco map[key.DiscoPublic]set.Set[key.NodePublic]
}
// newPeerMap returns an empty peerMap with every index allocated and ready
// for writes.
func newPeerMap() peerMap {
	var m peerMap
	m.byNodeKey = make(map[key.NodePublic]*peerInfo)
	m.byEpAddr = make(map[epAddr]*peerInfo)
	m.byNodeID = make(map[tailcfg.NodeID]*peerInfo)
	m.nodesOfDisco = make(map[key.DiscoPublic]set.Set[key.NodePublic])
	return m
}
// nodeCount reports how many nodes m currently holds. It calls devPanicf if
// the node-key and node-ID indexes disagree on that count.
func (m *peerMap) nodeCount() int {
	count := len(m.byNodeKey)
	if count != len(m.byNodeID) {
		devPanicf("internal error: peerMap.byNodeKey and byNodeID out of sync")
	}
	return count
}
// knownPeerDiscoKey reports whether there exists any peer with the disco key
// dk.
//
// It checks for a non-empty node set rather than mere key presence, because
// a disco key's set can be left empty after its last node is removed. An
// empty set means no peer uses dk anymore.
func (m *peerMap) knownPeerDiscoKey(dk key.DiscoPublic) bool {
	return len(m.nodesOfDisco[dk]) > 0
}
// endpointForNodeKey looks up the endpoint registered under node key nk.
// The zero key never matches. If nk is zero or unknown, it returns
// (nil, false).
func (m *peerMap) endpointForNodeKey(nk key.NodePublic) (ep *endpoint, ok bool) {
	if nk.IsZero() {
		return nil, false
	}
	pi, found := m.byNodeKey[nk]
	if !found {
		return nil, false
	}
	return pi.ep, true
}
// endpointForNodeID looks up the endpoint for the node with ID nodeID.
// ok is false (and ep nil) when no such node is indexed.
func (m *peerMap) endpointForNodeID(nodeID tailcfg.NodeID) (ep *endpoint, ok bool) {
	pi, found := m.byNodeID[nodeID]
	if found {
		ep, ok = pi.ep, true
	}
	return ep, ok
}
// endpointForEpAddr looks up the endpoint of the peer currently mapped to
// addr. It returns (nil, false) when no peer is mapped to addr.
func (m *peerMap) endpointForEpAddr(addr epAddr) (ep *endpoint, ok bool) {
	pi, found := m.byEpAddr[addr]
	if !found {
		return nil, false
	}
	return pi.ep, true
}
// forEachEndpoint calls f once for the endpoint of each peer indexed by node
// key, in map iteration order.
func (m *peerMap) forEachEndpoint(f func(ep *endpoint)) {
	for _, info := range m.byNodeKey {
		ep := info.ep
		f(ep)
	}
}
// forEachEndpointWithDiscoKey calls f for the endpoint of each node using
// disco key dk. Iteration stops early as soon as f returns false.
func (m *peerMap) forEachEndpointWithDiscoKey(dk key.DiscoPublic, f func(*endpoint) (keepGoing bool)) {
	nodes := m.nodesOfDisco[dk]
	for nk := range nodes {
		// A node key listed under dk but missing from byNodeKey would mean
		// the indexes are out of sync. That is unexpected, but there is no
		// logger here to report it, so such entries are skipped silently.
		// Maybe log later once peerMap is merged back into Conn.
		info, known := m.byNodeKey[nk]
		if known && !f(info.ep) {
			return
		}
	}
}
// upsertEndpoint indexes ep under ep.publicKey and ep.nodeID, creating its
// peerInfo if the node key is new, and updates the nodesOfDisco index.
//
// oldDiscoKey is the disco key ep was previously indexed under (possibly the
// zero key). ep's node key is removed from that disco key's set unless ep
// still has the same disco key. A disco key whose set becomes empty is
// dropped entirely, so nodesOfDisco does not accumulate empty sets as peers
// rotate disco keys.
//
// It panics if ep.nodeID is zero.
func (m *peerMap) upsertEndpoint(ep *endpoint, oldDiscoKey key.DiscoPublic) {
	if ep.nodeID == 0 {
		panic("internal error: upsertEndpoint called with zero NodeID")
	}
	pi, ok := m.byNodeKey[ep.publicKey]
	if !ok {
		pi = newPeerInfo(ep)
		m.byNodeKey[ep.publicKey] = pi
	}
	m.byNodeID[ep.nodeID] = pi

	epDisco := ep.disco.Load()
	if epDisco == nil || oldDiscoKey != epDisco.key {
		if oldSet, found := m.nodesOfDisco[oldDiscoKey]; found {
			delete(oldSet, ep.publicKey)
			if len(oldSet) == 0 {
				delete(m.nodesOfDisco, oldDiscoKey)
			}
		}
	}
	if ep.isWireguardOnly || epDisco == nil {
		// WireGuard-only peers have no disco tracking, and a peer without a
		// disco key has nothing to index. The nil check also prevents a nil
		// dereference of epDisco below.
		return
	}
	discoSet := m.nodesOfDisco[epDisco.key]
	if discoSet == nil {
		discoSet = set.Set[key.NodePublic]{}
		m.nodesOfDisco[epDisco.key] = discoSet
	}
	discoSet.Add(ep.publicKey)
}
// clearEpAddrForNodeKey clears the [epAddr] associated with nk. This is
// called by an [*endpoint] when wireguard-go signals a mismatch between
// a Cryptokey Routing identification outcome and the peer we believe to be
// associated with the packet.
//
// NATs (including UDP relay servers) can cause collisions of [epAddr]s across
// peers. This function resolves such collisions when they occur. A subsequent
// lookup via endpointForEpAddr() will fail, leading to resolution via
// [*lazyEndpoint] Cryptokey Routing identification.
func (m *peerMap) clearEpAddrForNodeKey(nk key.NodePublic) {
	pi := m.byNodeKey[nk]
	if pi == nil {
		return
	}
	// Only remove the index entry if it still refers to this peer. The
	// address may already have been reassigned to another peer, and that
	// peer's mapping must survive.
	if m.byEpAddr[pi.epAddr] == pi {
		delete(m.byEpAddr, pi.epAddr)
	}
	// Forget the address on the peerInfo too. Otherwise a later
	// deleteEndpoint or clear for nk would act on a stale address that may
	// by then belong to a different peer.
	pi.epAddr = epAddr{}
}
// setNodeKeyForEpAddr makes future peer lookups by addr return the
// same endpoint as a lookup by nk.
//
// This should only be called with a fully verified mapping of addr to
// nk, because calling this function defines the endpoint we hand to
// WireGuard for packets received from addr.
//
// A node key maps to at most one [epAddr]: any address previously mapped
// to nk is removed from byEpAddr. If addr was mapped to a different peer,
// that peer loses it.
func (m *peerMap) setNodeKeyForEpAddr(addr epAddr, nk key.NodePublic) {
	// Detach addr from whichever peer currently owns it.
	if prev := m.byEpAddr[addr]; prev != nil {
		prev.epAddr = epAddr{}
		delete(m.byEpAddr, addr)
	}
	pi, ok := m.byNodeKey[nk]
	if !ok {
		return
	}
	// Drop nk's previous address mapping, but only if it still refers to
	// nk's peerInfo. Leaving it in place would keep a second address pointing
	// at this peer. That stale entry would not be removed by
	// clearEpAddrForNodeKey or deleteEndpoint, which only look at pi.epAddr.
	if m.byEpAddr[pi.epAddr] == pi {
		delete(m.byEpAddr, pi.epAddr)
	}
	pi.epAddr = addr
	m.byEpAddr[addr] = pi
}
// deleteEndpoint stops and resets ep, then removes it from every index:
// node key, node ID (only if that entry still refers to ep), disco-key set,
// and epAddr. A disco key whose node set becomes empty is dropped from
// nodesOfDisco. A nil ep is a no-op.
func (m *peerMap) deleteEndpoint(ep *endpoint) {
	if ep == nil {
		return
	}
	ep.stopAndReset()
	epDisco := ep.disco.Load()
	pi := m.byNodeKey[ep.publicKey]
	if epDisco != nil {
		if nodes, found := m.nodesOfDisco[epDisco.key]; found {
			delete(nodes, ep.publicKey)
			if len(nodes) == 0 {
				delete(m.nodesOfDisco, epDisco.key)
			}
		}
	}
	delete(m.byNodeKey, ep.publicKey)
	if was, ok := m.byNodeID[ep.nodeID]; ok && was.ep == ep {
		delete(m.byNodeID, ep.nodeID)
	}
	if pi == nil {
		// Kneejerk paranoia from earlier issue 2801.
		// Unexpected. But no logger plumbed here to log so.
		return
	}
	// Only remove the address entry if it still refers to this peer. The
	// address recorded on pi may since have been handed to another peer,
	// whose mapping must be left intact.
	if m.byEpAddr[pi.epAddr] == pi {
		delete(m.byEpAddr, pi.epAddr)
	}
}