net,wgengine: add support for disco key exchnage via TSMP

Updates tailscale/corp#34037

Signed-off-by: James Tucker <james@tailscale.com>
This commit is contained in:
James Tucker
2025-11-03 14:53:11 -08:00
parent 3b865d7c33
commit adf7bbf902
8 changed files with 818 additions and 12 deletions

View File

@@ -18,7 +18,7 @@ import (
"tailscale.com/types/ipproto"
)
const minTSMPSize = 7 // the rejected body is 7 bytes
const minTSMPSize = 1 // minimum is 1 byte for the type field (e.g., disco key request 'd')
// TailscaleRejectedHeader is a TSMP message that says that one
// Tailscale node has rejected the connection from another. Unlike a
@@ -72,6 +72,12 @@ const (
// TSMPTypePong is the type byte for a TailscalePongResponse.
TSMPTypePong TSMPType = 'o'
// TSMPTypeDiscoKeyRequest is the type byte for a disco key request.
TSMPTypeDiscoKeyRequest TSMPType = 'd'
// TSMPTypeDiscoKeyUpdate is the type byte for a disco key update.
TSMPTypeDiscoKeyUpdate TSMPType = 'D'
)
type TailscaleRejectReason byte
@@ -259,3 +265,63 @@ func (h TSMPPongReply) Marshal(buf []byte) error {
binary.BigEndian.PutUint16(buf[9:11], h.PeerAPIPort)
return nil
}
// TSMPDiscoKeyRequest is a TSMP message that requests a peer's disco key.
//
// On the wire, after the IP header, it's currently 1 byte:
// - 'd' (TSMPTypeDiscoKeyRequest)
type TSMPDiscoKeyRequest struct{}
func (pp *Parsed) AsTSMPDiscoKeyRequest() (h TSMPDiscoKeyRequest, ok bool) {
if pp.IPProto != ipproto.TSMP {
return
}
p := pp.Payload()
if len(p) < 1 || p[0] != byte(TSMPTypeDiscoKeyRequest) {
return
}
return h, true
}
// TSMPDiscoKeyUpdate is a TSMP message that contains a disco public key.
// It may be sent in response to a request, or unsolicited when a node
// believes its peer may have stale disco key information.
//
// On the wire, after the IP header, it's currently 33 bytes:
// - 'D' (TSMPTypeDiscoKeyUpdate)
// - 32 bytes disco public key
type TSMPDiscoKeyUpdate struct {
IPHeader Header
DiscoKey [32]byte // raw disco public key bytes
}
// AsTSMPDiscoKeyUpdate returns pp as a TSMPDiscoKeyUpdate and whether it is one.
// The update.IPHeader field is not populated.
func (pp *Parsed) AsTSMPDiscoKeyUpdate() (update TSMPDiscoKeyUpdate, ok bool) {
if pp.IPProto != ipproto.TSMP {
return
}
p := pp.Payload()
if len(p) < 33 || p[0] != byte(TSMPTypeDiscoKeyUpdate) {
return
}
copy(update.DiscoKey[:], p[1:33])
return update, true
}
func (h TSMPDiscoKeyUpdate) Len() int {
return h.IPHeader.Len() + 33
}
func (h TSMPDiscoKeyUpdate) Marshal(buf []byte) error {
if len(buf) < h.Len() {
return errSmallBuffer
}
if err := h.IPHeader.Marshal(buf); err != nil {
return err
}
buf = buf[h.IPHeader.Len():]
buf[0] = byte(TSMPTypeDiscoKeyUpdate)
copy(buf[1:33], h.DiscoKey[:])
return nil
}

View File

@@ -71,3 +71,171 @@ func TestTailscaleRejectedHeader(t *testing.T) {
}
}
}
func TestTSMPDiscoKeyRequest(t *testing.T) {
t.Run("Manual", func(t *testing.T) {
var payload [1]byte
payload[0] = byte(TSMPTypeDiscoKeyRequest)
var p Parsed
p.IPProto = TSMP
p.dataofs = 40 // simulate after IP header
buf := make([]byte, 40+1)
copy(buf[40:], payload[:])
p.b = buf
p.length = len(buf)
_, ok := p.AsTSMPDiscoKeyRequest()
if !ok {
t.Fatal("failed to parse TSMP disco key request")
}
})
t.Run("RoundTripIPv4", func(t *testing.T) {
src := netip.MustParseAddr("100.64.0.1")
dst := netip.MustParseAddr("100.64.0.2")
iph := IP4Header{
IPProto: TSMP,
Src: src,
Dst: dst,
}
var payload [1]byte
payload[0] = byte(TSMPTypeDiscoKeyRequest)
pkt := Generate(iph, payload[:])
t.Logf("Generated packet: %d bytes, hex: %x", len(pkt), pkt)
// Manually check what decode4 would see
if len(pkt) >= 4 {
declaredLen := int(uint16(pkt[2])<<8 | uint16(pkt[3]))
t.Logf("Packet buffer length: %d, IP header declares length: %d", len(pkt), declaredLen)
t.Logf("Protocol byte at [9]: 0x%02x = %d", pkt[9], pkt[9])
}
var p Parsed
p.Decode(pkt)
t.Logf("Decoded: IPVersion=%d IPProto=%v Src=%v Dst=%v length=%d dataofs=%d",
p.IPVersion, p.IPProto, p.Src, p.Dst, p.length, p.dataofs)
if p.IPVersion != 4 {
t.Errorf("IPVersion = %d, want 4", p.IPVersion)
}
if p.IPProto != TSMP {
t.Errorf("IPProto = %v, want TSMP", p.IPProto)
}
if p.Src.Addr() != src {
t.Errorf("Src = %v, want %v", p.Src.Addr(), src)
}
if p.Dst.Addr() != dst {
t.Errorf("Dst = %v, want %v", p.Dst.Addr(), dst)
}
_, ok := p.AsTSMPDiscoKeyRequest()
if !ok {
t.Fatal("failed to parse TSMP disco key request from generated packet")
}
})
t.Run("RoundTripIPv6", func(t *testing.T) {
src := netip.MustParseAddr("2001:db8::1")
dst := netip.MustParseAddr("2001:db8::2")
iph := IP6Header{
IPProto: TSMP,
Src: src,
Dst: dst,
}
var payload [1]byte
payload[0] = byte(TSMPTypeDiscoKeyRequest)
pkt := Generate(iph, payload[:])
t.Logf("Generated packet: %d bytes", len(pkt))
var p Parsed
p.Decode(pkt)
if p.IPVersion != 6 {
t.Errorf("IPVersion = %d, want 6", p.IPVersion)
}
if p.IPProto != TSMP {
t.Errorf("IPProto = %v, want TSMP", p.IPProto)
}
if p.Src.Addr() != src {
t.Errorf("Src = %v, want %v", p.Src.Addr(), src)
}
if p.Dst.Addr() != dst {
t.Errorf("Dst = %v, want %v", p.Dst.Addr(), dst)
}
_, ok := p.AsTSMPDiscoKeyRequest()
if !ok {
t.Fatal("failed to parse TSMP disco key request from generated packet")
}
})
}
func TestTSMPDiscoKeyUpdate(t *testing.T) {
var discoKey [32]byte
for i := range discoKey {
discoKey[i] = byte(i + 10)
}
// Test IPv4
t.Run("IPv4", func(t *testing.T) {
update := TSMPDiscoKeyUpdate{
IPHeader: IP4Header{
IPProto: TSMP,
Src: netip.MustParseAddr("1.2.3.4"),
Dst: netip.MustParseAddr("5.6.7.8"),
},
DiscoKey: discoKey,
}
pkt := make([]byte, update.Len())
if err := update.Marshal(pkt); err != nil {
t.Fatal(err)
}
var p Parsed
p.Decode(pkt)
parsed, ok := p.AsTSMPDiscoKeyUpdate()
if !ok {
t.Fatal("failed to parse TSMP disco key update")
}
if parsed.DiscoKey != discoKey {
t.Errorf("disco key mismatch: got %v, want %v", parsed.DiscoKey, discoKey)
}
})
// Test IPv6
t.Run("IPv6", func(t *testing.T) {
update := TSMPDiscoKeyUpdate{
IPHeader: IP6Header{
IPProto: TSMP,
Src: netip.MustParseAddr("2001:db8::1"),
Dst: netip.MustParseAddr("2001:db8::2"),
},
DiscoKey: discoKey,
}
pkt := make([]byte, update.Len())
if err := update.Marshal(pkt); err != nil {
t.Fatal(err)
}
var p Parsed
p.Decode(pkt)
parsed, ok := p.AsTSMPDiscoKeyUpdate()
if !ok {
t.Fatal("failed to parse TSMP disco key update")
}
if parsed.DiscoKey != discoKey {
t.Errorf("disco key mismatch: got %v, want %v", parsed.DiscoKey, discoKey)
}
})
}

View File

@@ -188,11 +188,19 @@ type Wrapper struct {
// OnTSMPPongReceived, if non-nil, is called whenever a TSMP pong arrives.
OnTSMPPongReceived func(packet.TSMPPongReply)
// OnTSMPDiscoKeyReceived, if non-nil, is called whenever a TSMP disco key update arrives.
// The srcIP parameter identifies the peer that sent the update.
OnTSMPDiscoKeyReceived func(srcIP netip.Addr, update packet.TSMPDiscoKeyUpdate)
// OnICMPEchoResponseReceived, if non-nil, is called whenever a ICMP echo response
// arrives. If the packet is to be handled internally this returns true,
// false otherwise.
OnICMPEchoResponseReceived func(*packet.Parsed) bool
// GetDiscoPublicKey, if non-nil, returns the local node's disco public key.
// This is called when responding to TSMP disco key requests.
GetDiscoPublicKey func() key.DiscoPublic
// PeerAPIPort, if non-nil, returns the peerapi port that's
// running for the given IP address.
PeerAPIPort func(netip.Addr) (port uint16, ok bool)
@@ -1132,6 +1140,15 @@ func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook pa
if f := t.OnTSMPPongReceived; f != nil {
f(data)
}
} else if _, ok := p.AsTSMPDiscoKeyRequest(); ok {
t.noteActivity()
t.injectOutboundDiscoKeyUpdate(p)
return filter.DropSilently, gro
} else if discoKeyUpdate, ok := p.AsTSMPDiscoKeyUpdate(); ok {
if f := t.OnTSMPDiscoKeyReceived; f != nil {
f(p.Src.Addr(), discoKeyUpdate)
}
return filter.DropSilently, gro
}
}
@@ -1440,6 +1457,36 @@ func (t *Wrapper) injectOutboundPong(pp *packet.Parsed, req packet.TSMPPingReque
t.InjectOutbound(packet.Generate(pong, nil))
}
func (t *Wrapper) injectOutboundDiscoKeyUpdate(pp *packet.Parsed) {
if t.GetDiscoPublicKey == nil {
return
}
discoKey := t.GetDiscoPublicKey()
if discoKey.IsZero() {
return
}
update := packet.TSMPDiscoKeyUpdate{
DiscoKey: discoKey.Raw32(),
}
switch pp.IPVersion {
case 4:
h4 := pp.IP4Header()
h4.ToResponse()
update.IPHeader = h4
case 6:
h6 := pp.IP6Header()
h6.ToResponse()
update.IPHeader = h6
default:
return
}
t.InjectOutbound(packet.Generate(update, nil))
}
// InjectOutbound makes the Wrapper device behave as if a packet
// with the given contents was sent to the network.
// It does not block, but takes ownership of the packet.

View File

@@ -721,6 +721,16 @@ func (c *Conn) processDERPReadResult(dm derpReadResult, b []byte) (n int, ep *en
update(0, netip.AddrPortFrom(ep.nodeAddr, 0), srcAddr.ap, 1, dm.n, true)
}
// Request disco key from peer via TSMP if we receive a WireGuard handshake
// over DERP without recent disco success. This handles the "WireGuard-first"
// case where WireGuard establishes a tunnel via DERP before disco succeeds
// (e.g., control plane unreachable or stale disco keys).
// We only trigger on data packets (not handshake packets) because the tunnel
// must be fully established before we can send TSMP requests through it.
if looksLikeWireGuardHandshake(b[:n]) && n > 0 {
go c.requestDiscoKeyViaTSMP(dm.src, ep)
}
c.metrics.inboundPacketsDERPTotal.Add(1)
c.metrics.inboundBytesDERPTotal.Add(int64(n))
return n, ep

View File

@@ -155,17 +155,18 @@ type Conn struct {
// This block mirrors the contents and field order of the Options
// struct. Initialized once at construction, then constant.
eventBus *eventbus.Bus
eventClient *eventbus.Client
logf logger.Logf
epFunc func([]tailcfg.Endpoint)
derpActiveFunc func()
idleFunc func() time.Duration // nil means unknown
testOnlyPacketListener nettype.PacketListener
noteRecvActivity func(key.NodePublic) // or nil, see Options.NoteRecvActivity
netMon *netmon.Monitor // must be non-nil
health *health.Tracker // or nil
controlKnobs *controlknobs.Knobs // or nil
eventBus *eventbus.Bus
eventClient *eventbus.Client
logf logger.Logf
epFunc func([]tailcfg.Endpoint)
derpActiveFunc func()
idleFunc func() time.Duration // nil means unknown
testOnlyPacketListener nettype.PacketListener
noteRecvActivity func(key.NodePublic) // or nil, see Options.NoteRecvActivity
sendTSMPDiscoKeyRequest func(netip.Addr) error // or nil, sends TSMP disco key request to peer
netMon *netmon.Monitor // must be non-nil
health *health.Tracker // or nil
controlKnobs *controlknobs.Knobs // or nil
// ================================================================
// No locking required to access these fields, either because
@@ -1800,6 +1801,15 @@ func looksLikeInitiationMsg(b []byte) bool {
binary.LittleEndian.Uint32(b) == device.MessageInitiationType
}
func looksLikeWireGuardHandshake(b []byte) bool {
if len(b) < 4 {
return false
}
msgType := binary.LittleEndian.Uint32(b)
return (len(b) == device.MessageInitiationSize && msgType == device.MessageInitiationType) ||
(len(b) == device.MessageResponseSize && msgType == device.MessageResponseType)
}
// receiveIP is the shared bits of ReceiveIPv4 and ReceiveIPv6.
//
// size is the length of 'b' to report up to wireguard-go (only relevant if
@@ -2857,6 +2867,14 @@ func (c *Conn) SetSilentDisco(v bool) {
})
}
// SetSendTSMPDiscoKeyRequest sets the callback function to send TSMP disco key requests.
// This is provided by the engine/tundev to inject TSMP packets.
func (c *Conn) SetSendTSMPDiscoKeyRequest(fn func(netip.Addr) error) {
c.mu.Lock()
defer c.mu.Unlock()
c.sendTSMPDiscoKeyRequest = fn
}
// SilentDisco returns true if silent disco is enabled, otherwise false.
func (c *Conn) SilentDisco() bool {
c.mu.Lock()
@@ -4104,6 +4122,13 @@ var (
metricUDPLifetimeCycleCompleteAt10sCliff = newUDPLifetimeCounter("magicsock_udp_lifetime_cycle_complete_at_10s_cliff")
metricUDPLifetimeCycleCompleteAt30sCliff = newUDPLifetimeCounter("magicsock_udp_lifetime_cycle_complete_at_30s_cliff")
metricUDPLifetimeCycleCompleteAt60sCliff = newUDPLifetimeCounter("magicsock_udp_lifetime_cycle_complete_at_60s_cliff")
// TSMP disco key exchange
metricTSMPDiscoKeyRequestSent = clientmetric.NewCounter("magicsock_tsmp_disco_key_request_sent")
metricTSMPDiscoKeyRequestError = clientmetric.NewCounter("magicsock_tsmp_disco_key_request_error")
metricTSMPDiscoKeyUpdateReceived = clientmetric.NewCounter("magicsock_tsmp_disco_key_update_received")
metricTSMPDiscoKeyUpdateApplied = clientmetric.NewCounter("magicsock_tsmp_disco_key_update_applied")
metricTSMPDiscoKeyUpdateUnknown = clientmetric.NewCounter("magicsock_tsmp_disco_key_update_unknown_peer")
)
// newUDPLifetimeCounter returns a new *clientmetric.Metric with the provided
@@ -4242,6 +4267,101 @@ func (le *lazyEndpoint) FromPeer(peerPublicKey [32]byte) {
// See http://go/corp/29422 & http://go/corp/30042
le.c.peerMap.setNodeKeyForEpAddr(le.src, pubKey)
le.c.logf("magicsock: lazyEndpoint.FromPeer(%v) setting epAddr(%v) in peerMap for node(%v)", pubKey.ShortString(), le.src, ep.nodeAddr)
// Request disco key from peer via TSMP if we establish a tunnel
// without a recent disco ping. This handles cases where WireGuard
// establishes a tunnel before disco succeeds (e.g., control plane
// unreachable or stale disco keys).
go le.c.requestDiscoKeyViaTSMP(pubKey, ep)
}
// requestDiscoKeyViaTSMP sends a TSMP disco key request to a peer if there
// hasn't been a recent disco ping.
func (c *Conn) requestDiscoKeyViaTSMP(nodeKey key.NodePublic, ep *endpoint) {
if c.sendTSMPDiscoKeyRequest == nil {
return
}
if !ep.nodeAddr.IsValid() {
return
}
epDisco := ep.disco.Load()
if epDisco != nil {
c.mu.Lock()
di := c.discoInfo[epDisco.key]
recentDiscoPing := di != nil && time.Since(di.lastPingTime) < discoPingInterval
c.mu.Unlock()
if recentDiscoPing {
return
}
}
// YUCK. once again goroutines fight back - we need to deterministically
// schedule this _after_ the wireguard handshake response or else we trigger
// the wireguard handshake race problem. Maybe it's ok though, as we should
// really be singleflighting this, and perhaps we just use a singleflight
// with a short cork.
time.Sleep(time.Millisecond)
c.logf("magicsock: sending TSMP disco key request to %v (%v)", nodeKey.ShortString(), ep.nodeAddr)
if err := c.sendTSMPDiscoKeyRequest(ep.nodeAddr); err != nil {
c.logf("magicsock: failed to send TSMP disco key request: %v", err)
metricTSMPDiscoKeyRequestError.Add(1)
return
}
metricTSMPDiscoKeyRequestSent.Add(1)
}
// HandleDiscoKeyUpdate processes a TSMP disco key update.
// The update may be solicited (in response to a request) or unsolicited.
// srcIP is the Tailscale IP of the peer that sent the update.
func (c *Conn) HandleDiscoKeyUpdate(srcIP netip.Addr, update packet.TSMPDiscoKeyUpdate) {
discoKey := key.DiscoPublicFromRaw32(mem.B(update.DiscoKey[:]))
c.logf("magicsock: received disco key update %v from %v", discoKey.ShortString(), srcIP)
metricTSMPDiscoKeyUpdateReceived.Add(1)
c.mu.Lock()
defer c.mu.Unlock()
var nodeKey key.NodePublic
var found bool
for _, peer := range c.peers.All() {
for _, addr := range peer.Addresses().All() {
if addr.Addr() == srcIP {
nodeKey = peer.Key()
found = true
break
}
}
if found {
break
}
}
if !found {
c.logf("magicsock: disco key update from unknown peer %v", srcIP)
metricTSMPDiscoKeyUpdateUnknown.Add(1)
return
}
ep, ok := c.peerMap.endpointForNodeKey(nodeKey)
if !ok {
c.logf("magicsock: endpoint not found for node %v", nodeKey.ShortString())
return
}
oldDiscoKey := key.DiscoPublic{}
if epDisco := ep.disco.Load(); epDisco != nil {
oldDiscoKey = epDisco.key
}
c.discoInfoForKnownPeerLocked(discoKey)
ep.disco.Store(&endpointDisco{
key: discoKey,
short: discoKey.ShortString(),
})
c.peerMap.upsertEndpoint(ep, oldDiscoKey)
c.logf("magicsock: updated disco key for peer %v to %v", nodeKey.ShortString(), discoKey.ShortString())
metricTSMPDiscoKeyUpdateApplied.Add(1)
}
// PeerRelays returns the current set of candidate peer relays.

View File

@@ -64,6 +64,7 @@ import (
"tailscale.com/types/netmap"
"tailscale.com/types/nettype"
"tailscale.com/types/ptr"
"tailscale.com/types/views"
"tailscale.com/util/cibuild"
"tailscale.com/util/clientmetric"
"tailscale.com/util/eventbus"
@@ -4305,3 +4306,58 @@ func TestRotateDiscoKeyMultipleTimes(t *testing.T) {
keys = append(keys, newKey)
}
}
func TestSendTSMPDiscoKeyRequest(t *testing.T) {
ep := &endpoint{
nodeID: 1,
publicKey: key.NewNode().Public(),
nodeAddr: netip.MustParseAddr("100.64.0.1"),
}
discoKey := key.NewDisco().Public()
ep.disco.Store(&endpointDisco{
key: discoKey,
short: discoKey.ShortString(),
})
conn := newConn(t.Logf)
ep.c = conn
tsmpRequestCalled := make(chan struct{}, 1)
var capturedIP netip.Addr
conn.sendTSMPDiscoKeyRequest = func(ip netip.Addr) error {
capturedIP = ip
tsmpRequestCalled <- struct{}{}
return nil
}
conn.mu.Lock()
conn.peers = views.SliceOf([]tailcfg.NodeView{
(&tailcfg.Node{
Key: ep.publicKey,
Addresses: []netip.Prefix{
netip.MustParsePrefix("100.64.0.1/32"),
},
}).View(),
})
conn.mu.Unlock()
var pubKey [32]byte
copy(pubKey[:], ep.publicKey.AppendTo(nil))
conn.peerMap.upsertEndpoint(ep, key.DiscoPublic{})
le := &lazyEndpoint{
c: conn,
src: epAddr{ap: netip.MustParseAddrPort("127.0.0.1:7777")},
}
le.FromPeer(pubKey)
select {
case <-tsmpRequestCalled:
if !capturedIP.IsValid() {
t.Error("TSMP request sent with invalid IP")
}
t.Logf("TSMP disco key request sent to %v", capturedIP)
case <-time.After(time.Second):
t.Error("TSMP disco key request was not sent")
}
}

View File

@@ -0,0 +1,291 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package magicsock
import (
"net/netip"
"testing"
"time"
"github.com/tailscale/wireguard-go/tun/tuntest"
"tailscale.com/net/netaddr"
"tailscale.com/net/packet"
"tailscale.com/tailcfg"
"tailscale.com/tstest"
"tailscale.com/types/ipproto"
"tailscale.com/types/key"
"tailscale.com/types/netmap"
"tailscale.com/util/set"
"tailscale.com/wgengine/wgcfg/nmcfg"
)
func TestTSMPDiscoKeyExchange(t *testing.T) {
tstest.ResourceCheck(t)
// Set up DERP and STUN servers
derpMap, cleanup := runDERPAndStun(t, t.Logf, localhostListener{}, netaddr.IPv4(127, 0, 0, 1))
defer cleanup()
// Create two magicsock peers
m1 := newMagicStack(t, t.Logf, localhostListener{}, derpMap)
defer m1.Close()
m2 := newMagicStack(t, t.Logf, localhostListener{}, derpMap)
defer m2.Close()
// Wire up TSMP hooks to enable disco key exchange
// This mimics what userspaceEngine does in wgengine/userspace.go
// Hook 0: GetDiscoPublicKey - allows TSMP replies to include current disco key
m1.tsTun.GetDiscoPublicKey = m1.conn.DiscoPublicKey
m2.tsTun.GetDiscoPublicKey = m2.conn.DiscoPublicKey
// Hook 1: OnTSMPDiscoKeyReceived - handle incoming TSMP disco key updates
m1.tsTun.OnTSMPDiscoKeyReceived = func(srcIP netip.Addr, update packet.TSMPDiscoKeyUpdate) {
t.Logf("m1: received TSMP disco key update from %v", srcIP)
m1.conn.HandleDiscoKeyUpdate(srcIP, update)
}
m2.tsTun.OnTSMPDiscoKeyReceived = func(srcIP netip.Addr, update packet.TSMPDiscoKeyUpdate) {
t.Logf("m2: received TSMP disco key update from %v", srcIP)
m2.conn.HandleDiscoKeyUpdate(srcIP, update)
}
sendTSMPDiscoKeyRequest := func(dstIP netip.Addr) error {
var srcIP netip.Addr
var stack *magicStack
switch dstIP {
case m1.IP():
srcIP = m2.IP()
stack = m2
t.Logf("m2: sending disco key request to m1")
case m2.IP():
srcIP = m1.IP()
stack = m1
t.Logf("m1: sending disco key request to m2")
}
// equivalent to the implementation in userspace.Engine
iph := packet.IP4Header{
IPProto: ipproto.TSMP,
Src: srcIP,
Dst: dstIP,
}
var tsmpPayload [1]byte
tsmpPayload[0] = byte(packet.TSMPTypeDiscoKeyRequest)
tsmpRequest := packet.Generate(iph, tsmpPayload[:])
return stack.tsTun.InjectOutbound(tsmpRequest)
}
// Hook 2: SetSendTSMPDiscoKeyRequest - send TSMP disco key requests
m1.conn.SetSendTSMPDiscoKeyRequest(sendTSMPDiscoKeyRequest)
m2.conn.SetSendTSMPDiscoKeyRequest(sendTSMPDiscoKeyRequest)
// Get initial disco keys
disco1Original := m1.conn.DiscoPublicKey()
disco2 := m2.conn.DiscoPublicKey()
t.Logf("m1: node=%v disco=%v", m1.Public().ShortString(), disco1Original.ShortString())
t.Logf("m2: node=%v disco=%v", m2.Public().ShortString(), disco2.ShortString())
// Wait for initial endpoints
var eps1, eps2 []tailcfg.Endpoint
select {
case eps1 = <-m1.epCh:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for m1 endpoints")
}
select {
case eps2 = <-m2.epCh:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for m2 endpoints")
}
// Build initial network maps and establish connection
nm1 := &netmap.NetworkMap{
NodeKey: m1.Public(),
SelfNode: (&tailcfg.Node{
Addresses: []netip.Prefix{netip.PrefixFrom(netaddr.IPv4(100, 64, 0, 1), 32)},
}).View(),
Peers: []tailcfg.NodeView{
(&tailcfg.Node{
ID: 2,
Key: m2.Public(),
DiscoKey: disco2,
Addresses: []netip.Prefix{netip.PrefixFrom(netaddr.IPv4(100, 64, 0, 2), 32)},
AllowedIPs: []netip.Prefix{netip.PrefixFrom(netaddr.IPv4(100, 64, 0, 2), 32)},
Endpoints: epFromTyped(eps2),
HomeDERP: 1,
}).View(),
},
}
nm2 := &netmap.NetworkMap{
NodeKey: m2.Public(),
SelfNode: (&tailcfg.Node{
Addresses: []netip.Prefix{netip.PrefixFrom(netaddr.IPv4(100, 64, 0, 2), 32)},
}).View(),
Peers: []tailcfg.NodeView{
(&tailcfg.Node{
ID: 1,
Key: m1.Public(),
DiscoKey: disco1Original,
Addresses: []netip.Prefix{netip.PrefixFrom(netaddr.IPv4(100, 64, 0, 1), 32)},
AllowedIPs: []netip.Prefix{netip.PrefixFrom(netaddr.IPv4(100, 64, 0, 1), 32)},
Endpoints: epFromTyped(eps1),
HomeDERP: 1,
}).View(),
},
}
cfg1, err := nmcfg.WGCfg(m1.privateKey, nm1, t.Logf, 0, "")
if err != nil {
t.Fatal(err)
}
cfg2, err := nmcfg.WGCfg(m2.privateKey, nm2, t.Logf, 0, "")
if err != nil {
t.Fatal(err)
}
nv1 := NodeViewsUpdate{
SelfNode: nm1.SelfNode,
Peers: nm1.Peers,
}
m1.conn.onNodeViewsUpdate(nv1)
peerSet1 := set.Set[key.NodePublic]{}
peerSet1.Add(m2.Public())
m1.conn.UpdatePeers(peerSet1)
nv2 := NodeViewsUpdate{
SelfNode: nm2.SelfNode,
Peers: nm2.Peers,
}
m2.conn.onNodeViewsUpdate(nv2)
peerSet2 := set.Set[key.NodePublic]{}
peerSet2.Add(m1.Public())
m2.conn.UpdatePeers(peerSet2)
if err := m1.Reconfig(cfg1); err != nil {
t.Fatal(err)
}
if err := m2.Reconfig(cfg2); err != nil {
t.Fatal(err)
}
t.Logf("=== INITIAL CONFIGURATION COMPLETE ===")
// Start goroutines to drain TUN inbound channels so TSMP packets can be received
drainTun := func(name string, stack *magicStack) {
go func() {
for {
select {
case <-t.Context().Done():
return
case pkt := <-stack.tun.Inbound:
var p packet.Parsed
p.Decode(pkt)
if p.IPProto == ipproto.TSMP {
t.Logf("%s: received TSMP packet on TUN inbound: %d bytes", name, len(pkt))
} else if p.IPProto == ipproto.ICMPv4 {
t.Logf("%s: received ICMPv4 packet on TUN inbound: %d bytes", name, len(pkt))
} else {
t.Logf("%s: received packet on TUN inbound: %d bytes, proto=%v", name, len(pkt), p.IPProto)
}
}
}
}()
}
drainTun("m1", m1)
drainTun("m2", m2)
initialRequestsSent := metricTSMPDiscoKeyRequestSent.Value()
initialUpdatesReceived := metricTSMPDiscoKeyUpdateReceived.Value()
initialUpdatesApplied := metricTSMPDiscoKeyUpdateApplied.Value()
t.Logf("Initial metrics: requests_sent=%d updates_received=%d updates_applied=%d",
initialRequestsSent, initialUpdatesReceived, initialUpdatesApplied)
t.Logf("=== ROTATING m1's DISCO KEY ===")
m1.conn.RotateDiscoKey()
disco1New := m1.conn.DiscoPublicKey()
if disco1Original.Compare(disco1New) == 0 {
t.Fatal("disco key failed to rotate")
}
t.Logf("Rotated: %v -> %v", disco1Original.ShortString(), disco1New.ShortString())
t.Logf("=== SENDING PACKETS TO TRIGGER TSMP EXCHANGE ===")
ping1to2 := tuntest.Ping(netip.MustParseAddr("100.64.0.2"), netip.MustParseAddr("100.64.0.1"))
// Send packets from m2 to m1 only - this will trigger m1's handshake initiation
// and when m2 receives the encrypted packet, it should trigger FromPeer -> TSMP
select {
case m1.tun.Outbound <- ping1to2:
default:
}
for {
time.Sleep(time.Millisecond)
// Check if m2 has learned m1's new disco key
st := m2.Status()
if ps, ok := st.Peer[m1.Public()]; ok && ps.CurAddr != "" {
t.Logf("Connection established after disco key rotation")
t.Logf("m2 -> m1 via %v", ps.CurAddr)
t.Logf("Disco key rotation: %v -> %v", disco1Original.ShortString(), disco1New.ShortString())
// Verify TSMP metrics incremented
finalRequestsSent := metricTSMPDiscoKeyRequestSent.Value()
finalUpdatesReceived := metricTSMPDiscoKeyUpdateReceived.Value()
finalUpdatesApplied := metricTSMPDiscoKeyUpdateApplied.Value()
t.Logf("Final metrics: requests_sent=%d updates_received=%d updates_applied=%d",
finalRequestsSent, finalUpdatesReceived, finalUpdatesApplied)
// Check that at least one TSMP request was sent
if finalRequestsSent <= initialRequestsSent {
t.Errorf("Expected TSMP disco key request to be sent, but metric did not increment: %d -> %d",
initialRequestsSent, finalRequestsSent)
} else {
t.Logf("✓ TSMP disco key request sent (metric: %d -> %d)",
initialRequestsSent, finalRequestsSent)
}
// Check that at least one TSMP update was received
if finalUpdatesReceived <= initialUpdatesReceived {
t.Errorf("Expected TSMP disco key update to be received, but metric did not increment: %d -> %d",
initialUpdatesReceived, finalUpdatesReceived)
} else {
t.Logf("✓ TSMP disco key update received (metric: %d -> %d)",
initialUpdatesReceived, finalUpdatesReceived)
}
// Check that at least one TSMP update was applied
if finalUpdatesApplied <= initialUpdatesApplied {
t.Errorf("Expected TSMP disco key update to be applied, but metric did not increment: %d -> %d",
initialUpdatesApplied, finalUpdatesApplied)
} else {
t.Logf("✓ TSMP disco key update applied (metric: %d -> %d)",
initialUpdatesApplied, finalUpdatesApplied)
}
// Verify error counter didn't increment
requestErrors := metricTSMPDiscoKeyRequestError.Value()
if requestErrors > 0 {
t.Logf("Warning: TSMP disco key request errors: %d", requestErrors)
}
unknownPeers := metricTSMPDiscoKeyUpdateUnknown.Value()
if unknownPeers > 0 {
t.Logf("Warning: TSMP disco key updates from unknown peers: %d", unknownPeers)
}
t.Logf("TSMP disco key exchange infrastructure is functional")
return
}
}
}

View File

@@ -466,6 +466,25 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error)
return true
}
e.tundev.OnTSMPDiscoKeyReceived = func(srcIP netip.Addr, update packet.TSMPDiscoKeyUpdate) {
e.logf("wgengine: got TSMP disco key update from %v", srcIP)
if e.magicConn != nil {
e.magicConn.HandleDiscoKeyUpdate(srcIP, update)
}
}
e.tundev.GetDiscoPublicKey = func() key.DiscoPublic {
if e.magicConn == nil {
return key.DiscoPublic{}
}
return e.magicConn.DiscoPublicKey()
}
// Wire up TSMP disco key request sending to magicsock
if e.magicConn != nil {
e.magicConn.SetSendTSMPDiscoKeyRequest(e.sendTSMPDiscoKeyRequest)
}
// wgdev takes ownership of tundev, will close it when closed.
e.logf("Creating WireGuard device...")
e.wgdev = wgcfg.NewDevice(e.tundev, e.magicConn.Bind(), e.wgLogger.DeviceLogger)
@@ -1563,6 +1582,35 @@ func (e *userspaceEngine) setTSMPPongCallback(data [8]byte, cb func(packet.TSMPP
}
}
// sendTSMPDiscoKeyRequest sends a TSMP disco key request to the given peer IP.
func (e *userspaceEngine) sendTSMPDiscoKeyRequest(ip netip.Addr) error {
srcIP, err := e.mySelfIPMatchingFamily(ip)
if err != nil {
return err
}
var iph packet.Header
if srcIP.Is4() {
iph = packet.IP4Header{
IPProto: ipproto.TSMP,
Src: srcIP,
Dst: ip,
}
} else {
iph = packet.IP6Header{
IPProto: ipproto.TSMP,
Src: srcIP,
Dst: ip,
}
}
var tsmpPayload [1]byte
tsmpPayload[0] = byte(packet.TSMPTypeDiscoKeyRequest)
tsmpRequest := packet.Generate(iph, tsmpPayload[:])
return e.tundev.InjectOutbound(tsmpRequest)
}
func (e *userspaceEngine) setICMPEchoResponseCallback(idSeq uint32, cb func()) {
e.mu.Lock()
defer e.mu.Unlock()