mirror of
https://github.com/tailscale/tailscale.git
synced 2025-12-01 09:32:08 +00:00
net,wgengine: add support for disco key exchange via TSMP
Updates tailscale/corp#34037 Signed-off-by: James Tucker <james@tailscale.com>
This commit is contained in:
@@ -18,7 +18,7 @@ import (
|
||||
"tailscale.com/types/ipproto"
|
||||
)
|
||||
|
||||
const minTSMPSize = 7 // the rejected body is 7 bytes
|
||||
const minTSMPSize = 1 // minimum is 1 byte for the type field (e.g., disco key request 'd')
|
||||
|
||||
// TailscaleRejectedHeader is a TSMP message that says that one
|
||||
// Tailscale node has rejected the connection from another. Unlike a
|
||||
@@ -72,6 +72,12 @@ const (
|
||||
|
||||
// TSMPTypePong is the type byte for a TailscalePongResponse.
|
||||
TSMPTypePong TSMPType = 'o'
|
||||
|
||||
// TSMPTypeDiscoKeyRequest is the type byte for a disco key request.
|
||||
TSMPTypeDiscoKeyRequest TSMPType = 'd'
|
||||
|
||||
// TSMPTypeDiscoKeyUpdate is the type byte for a disco key update.
|
||||
TSMPTypeDiscoKeyUpdate TSMPType = 'D'
|
||||
)
|
||||
|
||||
type TailscaleRejectReason byte
|
||||
@@ -259,3 +265,63 @@ func (h TSMPPongReply) Marshal(buf []byte) error {
|
||||
binary.BigEndian.PutUint16(buf[9:11], h.PeerAPIPort)
|
||||
return nil
|
||||
}
|
||||
|
||||
// TSMPDiscoKeyRequest is a TSMP message that requests a peer's disco key.
|
||||
//
|
||||
// On the wire, after the IP header, it's currently 1 byte:
|
||||
// - 'd' (TSMPTypeDiscoKeyRequest)
|
||||
type TSMPDiscoKeyRequest struct{}
|
||||
|
||||
func (pp *Parsed) AsTSMPDiscoKeyRequest() (h TSMPDiscoKeyRequest, ok bool) {
|
||||
if pp.IPProto != ipproto.TSMP {
|
||||
return
|
||||
}
|
||||
p := pp.Payload()
|
||||
if len(p) < 1 || p[0] != byte(TSMPTypeDiscoKeyRequest) {
|
||||
return
|
||||
}
|
||||
return h, true
|
||||
}
|
||||
|
||||
// TSMPDiscoKeyUpdate is a TSMP message that contains a disco public key.
|
||||
// It may be sent in response to a request, or unsolicited when a node
|
||||
// believes its peer may have stale disco key information.
|
||||
//
|
||||
// On the wire, after the IP header, it's currently 33 bytes:
|
||||
// - 'D' (TSMPTypeDiscoKeyUpdate)
|
||||
// - 32 bytes disco public key
|
||||
type TSMPDiscoKeyUpdate struct {
|
||||
IPHeader Header
|
||||
DiscoKey [32]byte // raw disco public key bytes
|
||||
}
|
||||
|
||||
// AsTSMPDiscoKeyUpdate returns pp as a TSMPDiscoKeyUpdate and whether it is one.
|
||||
// The update.IPHeader field is not populated.
|
||||
func (pp *Parsed) AsTSMPDiscoKeyUpdate() (update TSMPDiscoKeyUpdate, ok bool) {
|
||||
if pp.IPProto != ipproto.TSMP {
|
||||
return
|
||||
}
|
||||
p := pp.Payload()
|
||||
if len(p) < 33 || p[0] != byte(TSMPTypeDiscoKeyUpdate) {
|
||||
return
|
||||
}
|
||||
copy(update.DiscoKey[:], p[1:33])
|
||||
return update, true
|
||||
}
|
||||
|
||||
func (h TSMPDiscoKeyUpdate) Len() int {
|
||||
return h.IPHeader.Len() + 33
|
||||
}
|
||||
|
||||
func (h TSMPDiscoKeyUpdate) Marshal(buf []byte) error {
|
||||
if len(buf) < h.Len() {
|
||||
return errSmallBuffer
|
||||
}
|
||||
if err := h.IPHeader.Marshal(buf); err != nil {
|
||||
return err
|
||||
}
|
||||
buf = buf[h.IPHeader.Len():]
|
||||
buf[0] = byte(TSMPTypeDiscoKeyUpdate)
|
||||
copy(buf[1:33], h.DiscoKey[:])
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -71,3 +71,171 @@ func TestTailscaleRejectedHeader(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTSMPDiscoKeyRequest(t *testing.T) {
|
||||
t.Run("Manual", func(t *testing.T) {
|
||||
var payload [1]byte
|
||||
payload[0] = byte(TSMPTypeDiscoKeyRequest)
|
||||
|
||||
var p Parsed
|
||||
p.IPProto = TSMP
|
||||
p.dataofs = 40 // simulate after IP header
|
||||
buf := make([]byte, 40+1)
|
||||
copy(buf[40:], payload[:])
|
||||
p.b = buf
|
||||
p.length = len(buf)
|
||||
|
||||
_, ok := p.AsTSMPDiscoKeyRequest()
|
||||
if !ok {
|
||||
t.Fatal("failed to parse TSMP disco key request")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("RoundTripIPv4", func(t *testing.T) {
|
||||
src := netip.MustParseAddr("100.64.0.1")
|
||||
dst := netip.MustParseAddr("100.64.0.2")
|
||||
|
||||
iph := IP4Header{
|
||||
IPProto: TSMP,
|
||||
Src: src,
|
||||
Dst: dst,
|
||||
}
|
||||
|
||||
var payload [1]byte
|
||||
payload[0] = byte(TSMPTypeDiscoKeyRequest)
|
||||
|
||||
pkt := Generate(iph, payload[:])
|
||||
t.Logf("Generated packet: %d bytes, hex: %x", len(pkt), pkt)
|
||||
|
||||
// Manually check what decode4 would see
|
||||
if len(pkt) >= 4 {
|
||||
declaredLen := int(uint16(pkt[2])<<8 | uint16(pkt[3]))
|
||||
t.Logf("Packet buffer length: %d, IP header declares length: %d", len(pkt), declaredLen)
|
||||
t.Logf("Protocol byte at [9]: 0x%02x = %d", pkt[9], pkt[9])
|
||||
}
|
||||
|
||||
var p Parsed
|
||||
p.Decode(pkt)
|
||||
t.Logf("Decoded: IPVersion=%d IPProto=%v Src=%v Dst=%v length=%d dataofs=%d",
|
||||
p.IPVersion, p.IPProto, p.Src, p.Dst, p.length, p.dataofs)
|
||||
|
||||
if p.IPVersion != 4 {
|
||||
t.Errorf("IPVersion = %d, want 4", p.IPVersion)
|
||||
}
|
||||
if p.IPProto != TSMP {
|
||||
t.Errorf("IPProto = %v, want TSMP", p.IPProto)
|
||||
}
|
||||
if p.Src.Addr() != src {
|
||||
t.Errorf("Src = %v, want %v", p.Src.Addr(), src)
|
||||
}
|
||||
if p.Dst.Addr() != dst {
|
||||
t.Errorf("Dst = %v, want %v", p.Dst.Addr(), dst)
|
||||
}
|
||||
|
||||
_, ok := p.AsTSMPDiscoKeyRequest()
|
||||
if !ok {
|
||||
t.Fatal("failed to parse TSMP disco key request from generated packet")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("RoundTripIPv6", func(t *testing.T) {
|
||||
src := netip.MustParseAddr("2001:db8::1")
|
||||
dst := netip.MustParseAddr("2001:db8::2")
|
||||
|
||||
iph := IP6Header{
|
||||
IPProto: TSMP,
|
||||
Src: src,
|
||||
Dst: dst,
|
||||
}
|
||||
|
||||
var payload [1]byte
|
||||
payload[0] = byte(TSMPTypeDiscoKeyRequest)
|
||||
|
||||
pkt := Generate(iph, payload[:])
|
||||
t.Logf("Generated packet: %d bytes", len(pkt))
|
||||
|
||||
var p Parsed
|
||||
p.Decode(pkt)
|
||||
|
||||
if p.IPVersion != 6 {
|
||||
t.Errorf("IPVersion = %d, want 6", p.IPVersion)
|
||||
}
|
||||
if p.IPProto != TSMP {
|
||||
t.Errorf("IPProto = %v, want TSMP", p.IPProto)
|
||||
}
|
||||
if p.Src.Addr() != src {
|
||||
t.Errorf("Src = %v, want %v", p.Src.Addr(), src)
|
||||
}
|
||||
if p.Dst.Addr() != dst {
|
||||
t.Errorf("Dst = %v, want %v", p.Dst.Addr(), dst)
|
||||
}
|
||||
|
||||
_, ok := p.AsTSMPDiscoKeyRequest()
|
||||
if !ok {
|
||||
t.Fatal("failed to parse TSMP disco key request from generated packet")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestTSMPDiscoKeyUpdate(t *testing.T) {
|
||||
var discoKey [32]byte
|
||||
for i := range discoKey {
|
||||
discoKey[i] = byte(i + 10)
|
||||
}
|
||||
|
||||
// Test IPv4
|
||||
t.Run("IPv4", func(t *testing.T) {
|
||||
update := TSMPDiscoKeyUpdate{
|
||||
IPHeader: IP4Header{
|
||||
IPProto: TSMP,
|
||||
Src: netip.MustParseAddr("1.2.3.4"),
|
||||
Dst: netip.MustParseAddr("5.6.7.8"),
|
||||
},
|
||||
DiscoKey: discoKey,
|
||||
}
|
||||
|
||||
pkt := make([]byte, update.Len())
|
||||
if err := update.Marshal(pkt); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var p Parsed
|
||||
p.Decode(pkt)
|
||||
|
||||
parsed, ok := p.AsTSMPDiscoKeyUpdate()
|
||||
if !ok {
|
||||
t.Fatal("failed to parse TSMP disco key update")
|
||||
}
|
||||
if parsed.DiscoKey != discoKey {
|
||||
t.Errorf("disco key mismatch: got %v, want %v", parsed.DiscoKey, discoKey)
|
||||
}
|
||||
})
|
||||
|
||||
// Test IPv6
|
||||
t.Run("IPv6", func(t *testing.T) {
|
||||
update := TSMPDiscoKeyUpdate{
|
||||
IPHeader: IP6Header{
|
||||
IPProto: TSMP,
|
||||
Src: netip.MustParseAddr("2001:db8::1"),
|
||||
Dst: netip.MustParseAddr("2001:db8::2"),
|
||||
},
|
||||
DiscoKey: discoKey,
|
||||
}
|
||||
|
||||
pkt := make([]byte, update.Len())
|
||||
if err := update.Marshal(pkt); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var p Parsed
|
||||
p.Decode(pkt)
|
||||
|
||||
parsed, ok := p.AsTSMPDiscoKeyUpdate()
|
||||
if !ok {
|
||||
t.Fatal("failed to parse TSMP disco key update")
|
||||
}
|
||||
if parsed.DiscoKey != discoKey {
|
||||
t.Errorf("disco key mismatch: got %v, want %v", parsed.DiscoKey, discoKey)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -188,11 +188,19 @@ type Wrapper struct {
|
||||
// OnTSMPPongReceived, if non-nil, is called whenever a TSMP pong arrives.
|
||||
OnTSMPPongReceived func(packet.TSMPPongReply)
|
||||
|
||||
// OnTSMPDiscoKeyReceived, if non-nil, is called whenever a TSMP disco key update arrives.
|
||||
// The srcIP parameter identifies the peer that sent the update.
|
||||
OnTSMPDiscoKeyReceived func(srcIP netip.Addr, update packet.TSMPDiscoKeyUpdate)
|
||||
|
||||
// OnICMPEchoResponseReceived, if non-nil, is called whenever a ICMP echo response
|
||||
// arrives. If the packet is to be handled internally this returns true,
|
||||
// false otherwise.
|
||||
OnICMPEchoResponseReceived func(*packet.Parsed) bool
|
||||
|
||||
// GetDiscoPublicKey, if non-nil, returns the local node's disco public key.
|
||||
// This is called when responding to TSMP disco key requests.
|
||||
GetDiscoPublicKey func() key.DiscoPublic
|
||||
|
||||
// PeerAPIPort, if non-nil, returns the peerapi port that's
|
||||
// running for the given IP address.
|
||||
PeerAPIPort func(netip.Addr) (port uint16, ok bool)
|
||||
@@ -1132,6 +1140,15 @@ func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook pa
|
||||
if f := t.OnTSMPPongReceived; f != nil {
|
||||
f(data)
|
||||
}
|
||||
} else if _, ok := p.AsTSMPDiscoKeyRequest(); ok {
|
||||
t.noteActivity()
|
||||
t.injectOutboundDiscoKeyUpdate(p)
|
||||
return filter.DropSilently, gro
|
||||
} else if discoKeyUpdate, ok := p.AsTSMPDiscoKeyUpdate(); ok {
|
||||
if f := t.OnTSMPDiscoKeyReceived; f != nil {
|
||||
f(p.Src.Addr(), discoKeyUpdate)
|
||||
}
|
||||
return filter.DropSilently, gro
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1440,6 +1457,36 @@ func (t *Wrapper) injectOutboundPong(pp *packet.Parsed, req packet.TSMPPingReque
|
||||
t.InjectOutbound(packet.Generate(pong, nil))
|
||||
}
|
||||
|
||||
func (t *Wrapper) injectOutboundDiscoKeyUpdate(pp *packet.Parsed) {
|
||||
if t.GetDiscoPublicKey == nil {
|
||||
return
|
||||
}
|
||||
|
||||
discoKey := t.GetDiscoPublicKey()
|
||||
if discoKey.IsZero() {
|
||||
return
|
||||
}
|
||||
|
||||
update := packet.TSMPDiscoKeyUpdate{
|
||||
DiscoKey: discoKey.Raw32(),
|
||||
}
|
||||
|
||||
switch pp.IPVersion {
|
||||
case 4:
|
||||
h4 := pp.IP4Header()
|
||||
h4.ToResponse()
|
||||
update.IPHeader = h4
|
||||
case 6:
|
||||
h6 := pp.IP6Header()
|
||||
h6.ToResponse()
|
||||
update.IPHeader = h6
|
||||
default:
|
||||
return
|
||||
}
|
||||
|
||||
t.InjectOutbound(packet.Generate(update, nil))
|
||||
}
|
||||
|
||||
// InjectOutbound makes the Wrapper device behave as if a packet
|
||||
// with the given contents was sent to the network.
|
||||
// It does not block, but takes ownership of the packet.
|
||||
|
||||
@@ -721,6 +721,16 @@ func (c *Conn) processDERPReadResult(dm derpReadResult, b []byte) (n int, ep *en
|
||||
update(0, netip.AddrPortFrom(ep.nodeAddr, 0), srcAddr.ap, 1, dm.n, true)
|
||||
}
|
||||
|
||||
// Request disco key from peer via TSMP if we receive a WireGuard handshake
|
||||
// over DERP without recent disco success. This handles the "WireGuard-first"
|
||||
// case where WireGuard establishes a tunnel via DERP before disco succeeds
|
||||
// (e.g., control plane unreachable or stale disco keys).
|
||||
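// NOTE(review): an earlier version of this comment said we only trigger on
// data packets (not handshake packets) because the tunnel must be fully
// established before TSMP requests can be sent through it. The check below
// fires on handshake packets; confirm which trigger is intended.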
|
||||
if looksLikeWireGuardHandshake(b[:n]) && n > 0 {
|
||||
go c.requestDiscoKeyViaTSMP(dm.src, ep)
|
||||
}
|
||||
|
||||
c.metrics.inboundPacketsDERPTotal.Add(1)
|
||||
c.metrics.inboundBytesDERPTotal.Add(int64(n))
|
||||
return n, ep
|
||||
|
||||
@@ -155,17 +155,18 @@ type Conn struct {
|
||||
// This block mirrors the contents and field order of the Options
|
||||
// struct. Initialized once at construction, then constant.
|
||||
|
||||
eventBus *eventbus.Bus
|
||||
eventClient *eventbus.Client
|
||||
logf logger.Logf
|
||||
epFunc func([]tailcfg.Endpoint)
|
||||
derpActiveFunc func()
|
||||
idleFunc func() time.Duration // nil means unknown
|
||||
testOnlyPacketListener nettype.PacketListener
|
||||
noteRecvActivity func(key.NodePublic) // or nil, see Options.NoteRecvActivity
|
||||
netMon *netmon.Monitor // must be non-nil
|
||||
health *health.Tracker // or nil
|
||||
controlKnobs *controlknobs.Knobs // or nil
|
||||
eventBus *eventbus.Bus
|
||||
eventClient *eventbus.Client
|
||||
logf logger.Logf
|
||||
epFunc func([]tailcfg.Endpoint)
|
||||
derpActiveFunc func()
|
||||
idleFunc func() time.Duration // nil means unknown
|
||||
testOnlyPacketListener nettype.PacketListener
|
||||
noteRecvActivity func(key.NodePublic) // or nil, see Options.NoteRecvActivity
|
||||
sendTSMPDiscoKeyRequest func(netip.Addr) error // or nil, sends TSMP disco key request to peer
|
||||
netMon *netmon.Monitor // must be non-nil
|
||||
health *health.Tracker // or nil
|
||||
controlKnobs *controlknobs.Knobs // or nil
|
||||
|
||||
// ================================================================
|
||||
// No locking required to access these fields, either because
|
||||
@@ -1800,6 +1801,15 @@ func looksLikeInitiationMsg(b []byte) bool {
|
||||
binary.LittleEndian.Uint32(b) == device.MessageInitiationType
|
||||
}
|
||||
|
||||
func looksLikeWireGuardHandshake(b []byte) bool {
|
||||
if len(b) < 4 {
|
||||
return false
|
||||
}
|
||||
msgType := binary.LittleEndian.Uint32(b)
|
||||
return (len(b) == device.MessageInitiationSize && msgType == device.MessageInitiationType) ||
|
||||
(len(b) == device.MessageResponseSize && msgType == device.MessageResponseType)
|
||||
}
|
||||
|
||||
// receiveIP is the shared bits of ReceiveIPv4 and ReceiveIPv6.
|
||||
//
|
||||
// size is the length of 'b' to report up to wireguard-go (only relevant if
|
||||
@@ -2857,6 +2867,14 @@ func (c *Conn) SetSilentDisco(v bool) {
|
||||
})
|
||||
}
|
||||
|
||||
// SetSendTSMPDiscoKeyRequest sets the callback function to send TSMP disco key requests.
|
||||
// This is provided by the engine/tundev to inject TSMP packets.
|
||||
func (c *Conn) SetSendTSMPDiscoKeyRequest(fn func(netip.Addr) error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.sendTSMPDiscoKeyRequest = fn
|
||||
}
|
||||
|
||||
// SilentDisco returns true if silent disco is enabled, otherwise false.
|
||||
func (c *Conn) SilentDisco() bool {
|
||||
c.mu.Lock()
|
||||
@@ -4104,6 +4122,13 @@ var (
|
||||
metricUDPLifetimeCycleCompleteAt10sCliff = newUDPLifetimeCounter("magicsock_udp_lifetime_cycle_complete_at_10s_cliff")
|
||||
metricUDPLifetimeCycleCompleteAt30sCliff = newUDPLifetimeCounter("magicsock_udp_lifetime_cycle_complete_at_30s_cliff")
|
||||
metricUDPLifetimeCycleCompleteAt60sCliff = newUDPLifetimeCounter("magicsock_udp_lifetime_cycle_complete_at_60s_cliff")
|
||||
|
||||
// TSMP disco key exchange
|
||||
metricTSMPDiscoKeyRequestSent = clientmetric.NewCounter("magicsock_tsmp_disco_key_request_sent")
|
||||
metricTSMPDiscoKeyRequestError = clientmetric.NewCounter("magicsock_tsmp_disco_key_request_error")
|
||||
metricTSMPDiscoKeyUpdateReceived = clientmetric.NewCounter("magicsock_tsmp_disco_key_update_received")
|
||||
metricTSMPDiscoKeyUpdateApplied = clientmetric.NewCounter("magicsock_tsmp_disco_key_update_applied")
|
||||
metricTSMPDiscoKeyUpdateUnknown = clientmetric.NewCounter("magicsock_tsmp_disco_key_update_unknown_peer")
|
||||
)
|
||||
|
||||
// newUDPLifetimeCounter returns a new *clientmetric.Metric with the provided
|
||||
@@ -4242,6 +4267,101 @@ func (le *lazyEndpoint) FromPeer(peerPublicKey [32]byte) {
|
||||
// See http://go/corp/29422 & http://go/corp/30042
|
||||
le.c.peerMap.setNodeKeyForEpAddr(le.src, pubKey)
|
||||
le.c.logf("magicsock: lazyEndpoint.FromPeer(%v) setting epAddr(%v) in peerMap for node(%v)", pubKey.ShortString(), le.src, ep.nodeAddr)
|
||||
|
||||
// Request disco key from peer via TSMP if we establish a tunnel
|
||||
// without a recent disco ping. This handles cases where WireGuard
|
||||
// establishes a tunnel before disco succeeds (e.g., control plane
|
||||
// unreachable or stale disco keys).
|
||||
go le.c.requestDiscoKeyViaTSMP(pubKey, ep)
|
||||
}
|
||||
|
||||
// requestDiscoKeyViaTSMP sends a TSMP disco key request to a peer if there
|
||||
// hasn't been a recent disco ping.
|
||||
func (c *Conn) requestDiscoKeyViaTSMP(nodeKey key.NodePublic, ep *endpoint) {
|
||||
if c.sendTSMPDiscoKeyRequest == nil {
|
||||
return
|
||||
}
|
||||
if !ep.nodeAddr.IsValid() {
|
||||
return
|
||||
}
|
||||
|
||||
epDisco := ep.disco.Load()
|
||||
if epDisco != nil {
|
||||
c.mu.Lock()
|
||||
di := c.discoInfo[epDisco.key]
|
||||
recentDiscoPing := di != nil && time.Since(di.lastPingTime) < discoPingInterval
|
||||
c.mu.Unlock()
|
||||
|
||||
if recentDiscoPing {
|
||||
return
|
||||
}
|
||||
}
|
||||
// YUCK. once again goroutines fight back - we need to deterministically
|
||||
// schedule this _after_ the wireguard handshake response or else we trigger
|
||||
// the wireguard handshake race problem. Maybe it's ok though, as we should
|
||||
// really be singleflighting this, and perhaps we just use a singleflight
|
||||
// with a short cork.
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
c.logf("magicsock: sending TSMP disco key request to %v (%v)", nodeKey.ShortString(), ep.nodeAddr)
|
||||
if err := c.sendTSMPDiscoKeyRequest(ep.nodeAddr); err != nil {
|
||||
c.logf("magicsock: failed to send TSMP disco key request: %v", err)
|
||||
metricTSMPDiscoKeyRequestError.Add(1)
|
||||
return
|
||||
}
|
||||
metricTSMPDiscoKeyRequestSent.Add(1)
|
||||
}
|
||||
|
||||
// HandleDiscoKeyUpdate processes a TSMP disco key update.
|
||||
// The update may be solicited (in response to a request) or unsolicited.
|
||||
// srcIP is the Tailscale IP of the peer that sent the update.
|
||||
func (c *Conn) HandleDiscoKeyUpdate(srcIP netip.Addr, update packet.TSMPDiscoKeyUpdate) {
|
||||
discoKey := key.DiscoPublicFromRaw32(mem.B(update.DiscoKey[:]))
|
||||
c.logf("magicsock: received disco key update %v from %v", discoKey.ShortString(), srcIP)
|
||||
metricTSMPDiscoKeyUpdateReceived.Add(1)
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
var nodeKey key.NodePublic
|
||||
var found bool
|
||||
for _, peer := range c.peers.All() {
|
||||
for _, addr := range peer.Addresses().All() {
|
||||
if addr.Addr() == srcIP {
|
||||
nodeKey = peer.Key()
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if found {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
c.logf("magicsock: disco key update from unknown peer %v", srcIP)
|
||||
metricTSMPDiscoKeyUpdateUnknown.Add(1)
|
||||
return
|
||||
}
|
||||
|
||||
ep, ok := c.peerMap.endpointForNodeKey(nodeKey)
|
||||
if !ok {
|
||||
c.logf("magicsock: endpoint not found for node %v", nodeKey.ShortString())
|
||||
return
|
||||
}
|
||||
|
||||
oldDiscoKey := key.DiscoPublic{}
|
||||
if epDisco := ep.disco.Load(); epDisco != nil {
|
||||
oldDiscoKey = epDisco.key
|
||||
}
|
||||
c.discoInfoForKnownPeerLocked(discoKey)
|
||||
ep.disco.Store(&endpointDisco{
|
||||
key: discoKey,
|
||||
short: discoKey.ShortString(),
|
||||
})
|
||||
c.peerMap.upsertEndpoint(ep, oldDiscoKey)
|
||||
c.logf("magicsock: updated disco key for peer %v to %v", nodeKey.ShortString(), discoKey.ShortString())
|
||||
metricTSMPDiscoKeyUpdateApplied.Add(1)
|
||||
}
|
||||
|
||||
// PeerRelays returns the current set of candidate peer relays.
|
||||
|
||||
@@ -64,6 +64,7 @@ import (
|
||||
"tailscale.com/types/netmap"
|
||||
"tailscale.com/types/nettype"
|
||||
"tailscale.com/types/ptr"
|
||||
"tailscale.com/types/views"
|
||||
"tailscale.com/util/cibuild"
|
||||
"tailscale.com/util/clientmetric"
|
||||
"tailscale.com/util/eventbus"
|
||||
@@ -4305,3 +4306,58 @@ func TestRotateDiscoKeyMultipleTimes(t *testing.T) {
|
||||
keys = append(keys, newKey)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendTSMPDiscoKeyRequest(t *testing.T) {
|
||||
ep := &endpoint{
|
||||
nodeID: 1,
|
||||
publicKey: key.NewNode().Public(),
|
||||
nodeAddr: netip.MustParseAddr("100.64.0.1"),
|
||||
}
|
||||
discoKey := key.NewDisco().Public()
|
||||
ep.disco.Store(&endpointDisco{
|
||||
key: discoKey,
|
||||
short: discoKey.ShortString(),
|
||||
})
|
||||
conn := newConn(t.Logf)
|
||||
ep.c = conn
|
||||
|
||||
tsmpRequestCalled := make(chan struct{}, 1)
|
||||
var capturedIP netip.Addr
|
||||
conn.sendTSMPDiscoKeyRequest = func(ip netip.Addr) error {
|
||||
capturedIP = ip
|
||||
tsmpRequestCalled <- struct{}{}
|
||||
return nil
|
||||
}
|
||||
|
||||
conn.mu.Lock()
|
||||
conn.peers = views.SliceOf([]tailcfg.NodeView{
|
||||
(&tailcfg.Node{
|
||||
Key: ep.publicKey,
|
||||
Addresses: []netip.Prefix{
|
||||
netip.MustParsePrefix("100.64.0.1/32"),
|
||||
},
|
||||
}).View(),
|
||||
})
|
||||
conn.mu.Unlock()
|
||||
|
||||
var pubKey [32]byte
|
||||
copy(pubKey[:], ep.publicKey.AppendTo(nil))
|
||||
conn.peerMap.upsertEndpoint(ep, key.DiscoPublic{})
|
||||
|
||||
le := &lazyEndpoint{
|
||||
c: conn,
|
||||
src: epAddr{ap: netip.MustParseAddrPort("127.0.0.1:7777")},
|
||||
}
|
||||
|
||||
le.FromPeer(pubKey)
|
||||
|
||||
select {
|
||||
case <-tsmpRequestCalled:
|
||||
if !capturedIP.IsValid() {
|
||||
t.Error("TSMP request sent with invalid IP")
|
||||
}
|
||||
t.Logf("TSMP disco key request sent to %v", capturedIP)
|
||||
case <-time.After(time.Second):
|
||||
t.Error("TSMP disco key request was not sent")
|
||||
}
|
||||
}
|
||||
|
||||
291
wgengine/magicsock/tsmp_disco_test.go
Normal file
291
wgengine/magicsock/tsmp_disco_test.go
Normal file
@@ -0,0 +1,291 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package magicsock
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/tailscale/wireguard-go/tun/tuntest"
|
||||
"tailscale.com/net/netaddr"
|
||||
"tailscale.com/net/packet"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/tstest"
|
||||
"tailscale.com/types/ipproto"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/netmap"
|
||||
"tailscale.com/util/set"
|
||||
"tailscale.com/wgengine/wgcfg/nmcfg"
|
||||
)
|
||||
|
||||
func TestTSMPDiscoKeyExchange(t *testing.T) {
|
||||
tstest.ResourceCheck(t)
|
||||
|
||||
// Set up DERP and STUN servers
|
||||
derpMap, cleanup := runDERPAndStun(t, t.Logf, localhostListener{}, netaddr.IPv4(127, 0, 0, 1))
|
||||
defer cleanup()
|
||||
|
||||
// Create two magicsock peers
|
||||
m1 := newMagicStack(t, t.Logf, localhostListener{}, derpMap)
|
||||
defer m1.Close()
|
||||
m2 := newMagicStack(t, t.Logf, localhostListener{}, derpMap)
|
||||
defer m2.Close()
|
||||
|
||||
// Wire up TSMP hooks to enable disco key exchange
|
||||
// This mimics what userspaceEngine does in wgengine/userspace.go
|
||||
|
||||
// Hook 0: GetDiscoPublicKey - allows TSMP replies to include current disco key
|
||||
m1.tsTun.GetDiscoPublicKey = m1.conn.DiscoPublicKey
|
||||
m2.tsTun.GetDiscoPublicKey = m2.conn.DiscoPublicKey
|
||||
|
||||
// Hook 1: OnTSMPDiscoKeyReceived - handle incoming TSMP disco key updates
|
||||
m1.tsTun.OnTSMPDiscoKeyReceived = func(srcIP netip.Addr, update packet.TSMPDiscoKeyUpdate) {
|
||||
t.Logf("m1: received TSMP disco key update from %v", srcIP)
|
||||
m1.conn.HandleDiscoKeyUpdate(srcIP, update)
|
||||
}
|
||||
m2.tsTun.OnTSMPDiscoKeyReceived = func(srcIP netip.Addr, update packet.TSMPDiscoKeyUpdate) {
|
||||
t.Logf("m2: received TSMP disco key update from %v", srcIP)
|
||||
m2.conn.HandleDiscoKeyUpdate(srcIP, update)
|
||||
}
|
||||
|
||||
sendTSMPDiscoKeyRequest := func(dstIP netip.Addr) error {
|
||||
var srcIP netip.Addr
|
||||
var stack *magicStack
|
||||
|
||||
switch dstIP {
|
||||
case m1.IP():
|
||||
srcIP = m2.IP()
|
||||
stack = m2
|
||||
t.Logf("m2: sending disco key request to m1")
|
||||
case m2.IP():
|
||||
srcIP = m1.IP()
|
||||
stack = m1
|
||||
t.Logf("m1: sending disco key request to m2")
|
||||
}
|
||||
|
||||
// equivalent to the implementation in userspace.Engine
|
||||
iph := packet.IP4Header{
|
||||
IPProto: ipproto.TSMP,
|
||||
Src: srcIP,
|
||||
Dst: dstIP,
|
||||
}
|
||||
|
||||
var tsmpPayload [1]byte
|
||||
tsmpPayload[0] = byte(packet.TSMPTypeDiscoKeyRequest)
|
||||
|
||||
tsmpRequest := packet.Generate(iph, tsmpPayload[:])
|
||||
return stack.tsTun.InjectOutbound(tsmpRequest)
|
||||
}
|
||||
|
||||
// Hook 2: SetSendTSMPDiscoKeyRequest - send TSMP disco key requests
|
||||
m1.conn.SetSendTSMPDiscoKeyRequest(sendTSMPDiscoKeyRequest)
|
||||
m2.conn.SetSendTSMPDiscoKeyRequest(sendTSMPDiscoKeyRequest)
|
||||
|
||||
// Get initial disco keys
|
||||
disco1Original := m1.conn.DiscoPublicKey()
|
||||
disco2 := m2.conn.DiscoPublicKey()
|
||||
|
||||
t.Logf("m1: node=%v disco=%v", m1.Public().ShortString(), disco1Original.ShortString())
|
||||
t.Logf("m2: node=%v disco=%v", m2.Public().ShortString(), disco2.ShortString())
|
||||
|
||||
// Wait for initial endpoints
|
||||
var eps1, eps2 []tailcfg.Endpoint
|
||||
select {
|
||||
case eps1 = <-m1.epCh:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for m1 endpoints")
|
||||
}
|
||||
select {
|
||||
case eps2 = <-m2.epCh:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for m2 endpoints")
|
||||
}
|
||||
|
||||
// Build initial network maps and establish connection
|
||||
nm1 := &netmap.NetworkMap{
|
||||
NodeKey: m1.Public(),
|
||||
SelfNode: (&tailcfg.Node{
|
||||
Addresses: []netip.Prefix{netip.PrefixFrom(netaddr.IPv4(100, 64, 0, 1), 32)},
|
||||
}).View(),
|
||||
Peers: []tailcfg.NodeView{
|
||||
(&tailcfg.Node{
|
||||
ID: 2,
|
||||
Key: m2.Public(),
|
||||
DiscoKey: disco2,
|
||||
Addresses: []netip.Prefix{netip.PrefixFrom(netaddr.IPv4(100, 64, 0, 2), 32)},
|
||||
AllowedIPs: []netip.Prefix{netip.PrefixFrom(netaddr.IPv4(100, 64, 0, 2), 32)},
|
||||
Endpoints: epFromTyped(eps2),
|
||||
HomeDERP: 1,
|
||||
}).View(),
|
||||
},
|
||||
}
|
||||
|
||||
nm2 := &netmap.NetworkMap{
|
||||
NodeKey: m2.Public(),
|
||||
SelfNode: (&tailcfg.Node{
|
||||
Addresses: []netip.Prefix{netip.PrefixFrom(netaddr.IPv4(100, 64, 0, 2), 32)},
|
||||
}).View(),
|
||||
Peers: []tailcfg.NodeView{
|
||||
(&tailcfg.Node{
|
||||
ID: 1,
|
||||
Key: m1.Public(),
|
||||
DiscoKey: disco1Original,
|
||||
Addresses: []netip.Prefix{netip.PrefixFrom(netaddr.IPv4(100, 64, 0, 1), 32)},
|
||||
AllowedIPs: []netip.Prefix{netip.PrefixFrom(netaddr.IPv4(100, 64, 0, 1), 32)},
|
||||
Endpoints: epFromTyped(eps1),
|
||||
HomeDERP: 1,
|
||||
}).View(),
|
||||
},
|
||||
}
|
||||
|
||||
cfg1, err := nmcfg.WGCfg(m1.privateKey, nm1, t.Logf, 0, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cfg2, err := nmcfg.WGCfg(m2.privateKey, nm2, t.Logf, 0, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
nv1 := NodeViewsUpdate{
|
||||
SelfNode: nm1.SelfNode,
|
||||
Peers: nm1.Peers,
|
||||
}
|
||||
m1.conn.onNodeViewsUpdate(nv1)
|
||||
|
||||
peerSet1 := set.Set[key.NodePublic]{}
|
||||
peerSet1.Add(m2.Public())
|
||||
m1.conn.UpdatePeers(peerSet1)
|
||||
|
||||
nv2 := NodeViewsUpdate{
|
||||
SelfNode: nm2.SelfNode,
|
||||
Peers: nm2.Peers,
|
||||
}
|
||||
m2.conn.onNodeViewsUpdate(nv2)
|
||||
|
||||
peerSet2 := set.Set[key.NodePublic]{}
|
||||
peerSet2.Add(m1.Public())
|
||||
m2.conn.UpdatePeers(peerSet2)
|
||||
|
||||
if err := m1.Reconfig(cfg1); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := m2.Reconfig(cfg2); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Logf("=== INITIAL CONFIGURATION COMPLETE ===")
|
||||
|
||||
// Start goroutines to drain TUN inbound channels so TSMP packets can be received
|
||||
drainTun := func(name string, stack *magicStack) {
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-t.Context().Done():
|
||||
return
|
||||
case pkt := <-stack.tun.Inbound:
|
||||
var p packet.Parsed
|
||||
p.Decode(pkt)
|
||||
if p.IPProto == ipproto.TSMP {
|
||||
t.Logf("%s: received TSMP packet on TUN inbound: %d bytes", name, len(pkt))
|
||||
} else if p.IPProto == ipproto.ICMPv4 {
|
||||
t.Logf("%s: received ICMPv4 packet on TUN inbound: %d bytes", name, len(pkt))
|
||||
} else {
|
||||
t.Logf("%s: received packet on TUN inbound: %d bytes, proto=%v", name, len(pkt), p.IPProto)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
drainTun("m1", m1)
|
||||
drainTun("m2", m2)
|
||||
|
||||
initialRequestsSent := metricTSMPDiscoKeyRequestSent.Value()
|
||||
initialUpdatesReceived := metricTSMPDiscoKeyUpdateReceived.Value()
|
||||
initialUpdatesApplied := metricTSMPDiscoKeyUpdateApplied.Value()
|
||||
|
||||
t.Logf("Initial metrics: requests_sent=%d updates_received=%d updates_applied=%d",
|
||||
initialRequestsSent, initialUpdatesReceived, initialUpdatesApplied)
|
||||
|
||||
t.Logf("=== ROTATING m1's DISCO KEY ===")
|
||||
m1.conn.RotateDiscoKey()
|
||||
disco1New := m1.conn.DiscoPublicKey()
|
||||
|
||||
if disco1Original.Compare(disco1New) == 0 {
|
||||
t.Fatal("disco key failed to rotate")
|
||||
}
|
||||
t.Logf("Rotated: %v -> %v", disco1Original.ShortString(), disco1New.ShortString())
|
||||
|
||||
t.Logf("=== SENDING PACKETS TO TRIGGER TSMP EXCHANGE ===")
|
||||
|
||||
ping1to2 := tuntest.Ping(netip.MustParseAddr("100.64.0.2"), netip.MustParseAddr("100.64.0.1"))
|
||||
|
||||
// Send packets from m2 to m1 only - this will trigger m1's handshake initiation
|
||||
// and when m2 receives the encrypted packet, it should trigger FromPeer -> TSMP
|
||||
select {
|
||||
case m1.tun.Outbound <- ping1to2:
|
||||
default:
|
||||
}
|
||||
|
||||
for {
|
||||
time.Sleep(time.Millisecond)
|
||||
// Check if m2 has learned m1's new disco key
|
||||
st := m2.Status()
|
||||
if ps, ok := st.Peer[m1.Public()]; ok && ps.CurAddr != "" {
|
||||
t.Logf("Connection established after disco key rotation")
|
||||
t.Logf("m2 -> m1 via %v", ps.CurAddr)
|
||||
t.Logf("Disco key rotation: %v -> %v", disco1Original.ShortString(), disco1New.ShortString())
|
||||
|
||||
// Verify TSMP metrics incremented
|
||||
finalRequestsSent := metricTSMPDiscoKeyRequestSent.Value()
|
||||
finalUpdatesReceived := metricTSMPDiscoKeyUpdateReceived.Value()
|
||||
finalUpdatesApplied := metricTSMPDiscoKeyUpdateApplied.Value()
|
||||
|
||||
t.Logf("Final metrics: requests_sent=%d updates_received=%d updates_applied=%d",
|
||||
finalRequestsSent, finalUpdatesReceived, finalUpdatesApplied)
|
||||
|
||||
// Check that at least one TSMP request was sent
|
||||
if finalRequestsSent <= initialRequestsSent {
|
||||
t.Errorf("Expected TSMP disco key request to be sent, but metric did not increment: %d -> %d",
|
||||
initialRequestsSent, finalRequestsSent)
|
||||
} else {
|
||||
t.Logf("✓ TSMP disco key request sent (metric: %d -> %d)",
|
||||
initialRequestsSent, finalRequestsSent)
|
||||
}
|
||||
|
||||
// Check that at least one TSMP update was received
|
||||
if finalUpdatesReceived <= initialUpdatesReceived {
|
||||
t.Errorf("Expected TSMP disco key update to be received, but metric did not increment: %d -> %d",
|
||||
initialUpdatesReceived, finalUpdatesReceived)
|
||||
} else {
|
||||
t.Logf("✓ TSMP disco key update received (metric: %d -> %d)",
|
||||
initialUpdatesReceived, finalUpdatesReceived)
|
||||
}
|
||||
|
||||
// Check that at least one TSMP update was applied
|
||||
if finalUpdatesApplied <= initialUpdatesApplied {
|
||||
t.Errorf("Expected TSMP disco key update to be applied, but metric did not increment: %d -> %d",
|
||||
initialUpdatesApplied, finalUpdatesApplied)
|
||||
} else {
|
||||
t.Logf("✓ TSMP disco key update applied (metric: %d -> %d)",
|
||||
initialUpdatesApplied, finalUpdatesApplied)
|
||||
}
|
||||
|
||||
// Verify error counter didn't increment
|
||||
requestErrors := metricTSMPDiscoKeyRequestError.Value()
|
||||
if requestErrors > 0 {
|
||||
t.Logf("Warning: TSMP disco key request errors: %d", requestErrors)
|
||||
}
|
||||
|
||||
unknownPeers := metricTSMPDiscoKeyUpdateUnknown.Value()
|
||||
if unknownPeers > 0 {
|
||||
t.Logf("Warning: TSMP disco key updates from unknown peers: %d", unknownPeers)
|
||||
}
|
||||
|
||||
t.Logf("TSMP disco key exchange infrastructure is functional")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -466,6 +466,25 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error)
|
||||
return true
|
||||
}
|
||||
|
||||
e.tundev.OnTSMPDiscoKeyReceived = func(srcIP netip.Addr, update packet.TSMPDiscoKeyUpdate) {
|
||||
e.logf("wgengine: got TSMP disco key update from %v", srcIP)
|
||||
if e.magicConn != nil {
|
||||
e.magicConn.HandleDiscoKeyUpdate(srcIP, update)
|
||||
}
|
||||
}
|
||||
|
||||
e.tundev.GetDiscoPublicKey = func() key.DiscoPublic {
|
||||
if e.magicConn == nil {
|
||||
return key.DiscoPublic{}
|
||||
}
|
||||
return e.magicConn.DiscoPublicKey()
|
||||
}
|
||||
|
||||
// Wire up TSMP disco key request sending to magicsock
|
||||
if e.magicConn != nil {
|
||||
e.magicConn.SetSendTSMPDiscoKeyRequest(e.sendTSMPDiscoKeyRequest)
|
||||
}
|
||||
|
||||
// wgdev takes ownership of tundev, will close it when closed.
|
||||
e.logf("Creating WireGuard device...")
|
||||
e.wgdev = wgcfg.NewDevice(e.tundev, e.magicConn.Bind(), e.wgLogger.DeviceLogger)
|
||||
@@ -1563,6 +1582,35 @@ func (e *userspaceEngine) setTSMPPongCallback(data [8]byte, cb func(packet.TSMPP
|
||||
}
|
||||
}
|
||||
|
||||
// sendTSMPDiscoKeyRequest sends a TSMP disco key request to the given peer IP.
|
||||
func (e *userspaceEngine) sendTSMPDiscoKeyRequest(ip netip.Addr) error {
|
||||
srcIP, err := e.mySelfIPMatchingFamily(ip)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var iph packet.Header
|
||||
if srcIP.Is4() {
|
||||
iph = packet.IP4Header{
|
||||
IPProto: ipproto.TSMP,
|
||||
Src: srcIP,
|
||||
Dst: ip,
|
||||
}
|
||||
} else {
|
||||
iph = packet.IP6Header{
|
||||
IPProto: ipproto.TSMP,
|
||||
Src: srcIP,
|
||||
Dst: ip,
|
||||
}
|
||||
}
|
||||
|
||||
var tsmpPayload [1]byte
|
||||
tsmpPayload[0] = byte(packet.TSMPTypeDiscoKeyRequest)
|
||||
|
||||
tsmpRequest := packet.Generate(iph, tsmpPayload[:])
|
||||
return e.tundev.InjectOutbound(tsmpRequest)
|
||||
}
|
||||
|
||||
func (e *userspaceEngine) setICMPEchoResponseCallback(idSeq uint32, cb func()) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
|
||||
Reference in New Issue
Block a user