net/{packet,tstun},wgengine: update disco key when receiving via TSMP

When receiving a TSMPDiscoAdvertisement from peer, update the discokey
for said peer.

Some parts taken from: https://github.com/tailscale/tailscale/pull/18073/

Updates #12639

Signed-off-by: Claus Lensbøl <claus@tailscale.com>

Co-authored-by: James Tucker <james@tailscale.com>
This commit is contained in:
Claus Lensbøl
2025-12-08 14:39:20 -05:00
parent d349370e55
commit 7732fa0fcf
6 changed files with 119 additions and 4 deletions

View File

@@ -271,7 +271,7 @@ func (h TSMPPongReply) Marshal(buf []byte) error {
// - 'a' (TSMPTypeDiscoAdvertisement)
// - 32 disco key bytes
type TSMPDiscoKeyAdvertisement struct {
Src, Dst netip.Addr
Src, Dst netip.Addr // Src and Dst are set from the parent IP Header when parsing.
Key key.DiscoPublic
}
@@ -298,7 +298,7 @@ func (ka *TSMPDiscoKeyAdvertisement) Marshal() ([]byte, error) {
return []byte{}, fmt.Errorf("expected payload length 33, got %d", len(payload))
}
return Generate(iph, payload), nil
return Generate(iph, payload[:]), nil
}
func (pp *Parsed) AsTSMPDiscoAdvertisement() (tka TSMPDiscoKeyAdvertisement, ok bool) {
@@ -310,6 +310,7 @@ func (pp *Parsed) AsTSMPDiscoAdvertisement() (tka TSMPDiscoKeyAdvertisement, ok
return
}
tka.Src = pp.Src.Addr()
tka.Dst = pp.Dst.Addr()
tka.Key = key.DiscoPublicFromRaw32(mem.B(p[1:33]))
return tka, true

View File

@@ -1126,8 +1126,10 @@ func (t *Wrapper) injectedRead(res tunInjectedRead, outBuffs [][]byte, sizes []i
return n, err
}
// DiscoKeyAdvertisement is a TSMP message used for distributing disco keys.
// This struct is used an an event on the [eventbus.Bus].
type DiscoKeyAdvertisement struct {
Src netip.Addr
Src netip.Addr // Src field is populated by the IP header of the packet, not from the payload itself.
Key key.DiscoPublic
}

View File

@@ -986,7 +986,7 @@ func TestTSMPDisco(t *testing.T) {
if tda.Src != src {
t.Errorf("Src address did not match, expected %v, got %v", src, tda.Src)
}
if !reflect.DeepEqual(tda.Key, discoKey.Public()) {
if tda.Key.Compare(discoKey.Public()) != 0 {
t.Errorf("Key did not match, expected %q, got %q", discoKey.Public(), tda.Key)
}
})

View File

@@ -4104,6 +4104,11 @@ var (
metricUDPLifetimeCycleCompleteAt10sCliff = newUDPLifetimeCounter("magicsock_udp_lifetime_cycle_complete_at_10s_cliff")
metricUDPLifetimeCycleCompleteAt30sCliff = newUDPLifetimeCounter("magicsock_udp_lifetime_cycle_complete_at_30s_cliff")
metricUDPLifetimeCycleCompleteAt60sCliff = newUDPLifetimeCounter("magicsock_udp_lifetime_cycle_complete_at_60s_cliff")
// TSMP disco key exchange
metricTSMPDiscoKeyAdvertisementReceived = clientmetric.NewCounter("magicsock_tsmp_disco_key_advertisement_received")
metricTSMPDiscoKeyAdvertisementApplied = clientmetric.NewCounter("magicsock_tsmp_disco_key_advertisement_applied")
metricTSMPDiscoKeyAdvertisementUnknown = clientmetric.NewCounter("magicsock_tsmp_disco_key_advertisement_unknown_peer")
)
// newUDPLifetimeCounter returns a new *clientmetric.Metric with the provided
@@ -4264,3 +4269,55 @@ func (c *Conn) PeerRelays() set.Set[netip.Addr] {
}
return servers
}
// HandleDiscoKeyAdvertisement processes a TSMP disco key update.
// The update may be solicited (in response to a request) or unsolicited.
// srcIP is the Tailscale IP of the peer that sent the update.
func (c *Conn) HandleDiscoKeyAdvertisement(srcIP netip.Addr, update packet.TSMPDiscoKeyAdvertisement) {
discoKey := update.Key
c.logf("magicsock: received disco key update %v from %v", discoKey.ShortString(), srcIP)
metricTSMPDiscoKeyAdvertisementReceived.Add(1)
c.mu.Lock()
defer c.mu.Unlock()
var nodeKey key.NodePublic
var found bool
for _, peer := range c.peers.All() {
for _, addr := range peer.Addresses().All() {
if addr.Addr() == srcIP {
nodeKey = peer.Key()
found = true
break
}
}
if found {
break
}
}
if !found {
c.logf("magicsock: disco key update from unknown peer %v", srcIP)
metricTSMPDiscoKeyAdvertisementUnknown.Add(1)
return
}
ep, ok := c.peerMap.endpointForNodeKey(nodeKey)
if !ok {
c.logf("magicsock: endpoint not found for node %v", nodeKey.ShortString())
return
}
oldDiscoKey := key.DiscoPublic{}
if epDisco := ep.disco.Load(); epDisco != nil {
oldDiscoKey = epDisco.key
}
c.discoInfoForKnownPeerLocked(discoKey)
ep.disco.Store(&endpointDisco{
key: discoKey,
short: discoKey.ShortString(),
})
c.peerMap.upsertEndpoint(ep, oldDiscoKey)
c.logf("magicsock: updated disco key for peer %v to %v", nodeKey.ShortString(), discoKey.ShortString())
metricTSMPDiscoKeyAdvertisementApplied.Add(1)
}

View File

@@ -64,6 +64,7 @@ import (
"tailscale.com/types/netmap"
"tailscale.com/types/nettype"
"tailscale.com/types/ptr"
"tailscale.com/types/views"
"tailscale.com/util/cibuild"
"tailscale.com/util/clientmetric"
"tailscale.com/util/eventbus"
@@ -4302,3 +4303,48 @@ func TestRotateDiscoKeyMultipleTimes(t *testing.T) {
keys = append(keys, newKey)
}
}
func TestReceiveTSMPDiscoKeyAdvertisement(t *testing.T) {
conn := newTestConn(t)
t.Cleanup(func() { conn.Close() })
peerKey := key.NewNode().Public()
ep := &endpoint{
nodeID: 1,
publicKey: peerKey,
nodeAddr: netip.MustParseAddr("100.64.0.1"),
}
discoKey := key.NewDisco().Public()
ep.disco.Store(&endpointDisco{
key: discoKey,
short: discoKey.ShortString(),
})
ep.c = conn
conn.mu.Lock()
conn.peers = views.SliceOf([]tailcfg.NodeView{
(&tailcfg.Node{
Key: ep.publicKey,
Addresses: []netip.Prefix{
netip.MustParsePrefix("100.64.0.1/32"),
},
}).View(),
})
conn.mu.Unlock()
conn.peerMap.upsertEndpoint(ep, key.DiscoPublic{})
if ep.discoShort() != discoKey.ShortString() {
t.Errorf("Original disco key %s, does not match %s", discoKey.ShortString(), ep.discoShort())
}
newDiscoKey := key.NewDisco().Public()
tka := packet.TSMPDiscoKeyAdvertisement{
Src: netip.MustParseAddr("100.64.0.1"),
Key: newDiscoKey,
}
conn.HandleDiscoKeyAdvertisement(netip.MustParseAddr("100.64.0.1"), tka)
if ep.disco.Load().short != newDiscoKey.ShortString() {
t.Errorf("New disco key %s, does not match %s", newDiscoKey.ShortString(), ep.disco.Load().short)
}
}

View File

@@ -551,6 +551,15 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error)
}
e.linkChangeQueue.Add(func() { e.linkChange(&cd) })
})
eventbus.SubscribeFunc(ec, func(update tstun.DiscoKeyAdvertisement) {
e.logf("wgengine: got TSMP disco key advertisement from %v via eventbus", update.Src)
if e.magicConn != nil {
pkt := packet.TSMPDiscoKeyAdvertisement{
Key: update.Key,
}
e.magicConn.HandleDiscoKeyAdvertisement(update.Src, pkt)
}
})
e.eventClient = ec
e.logf("Engine created.")
return e, nil