wgengine/magicsock: introduce virtualNetworkID type (#16021)

This type improves code clarity and reduces the chance of heap alloc as
we pass it as a non-pointer. VNI being a 3-byte value enables us to
track set vs unset via the reserved/unused byte.

Updates tailscale/corp#27502

Signed-off-by: Jordan Whited <jordan@tailscale.com>
This commit is contained in:
Jordan Whited 2025-05-19 19:14:08 -07:00 committed by GitHub
parent 30a89ad378
commit 3cc80cce6a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 100 additions and 14 deletions

View File

@ -1112,7 +1112,7 @@ func (de *endpoint) sendDiscoPing(ep netip.AddrPort, discoKey key.DiscoPublic, t
size = min(size, MaxDiscoPingSize) size = min(size, MaxDiscoPingSize)
padding := max(size-discoPingSize, 0) padding := max(size-discoPingSize, 0)
sent, _ := de.c.sendDiscoMessage(ep, nil, de.publicKey, discoKey, &disco.Ping{ sent, _ := de.c.sendDiscoMessage(ep, virtualNetworkID{}, de.publicKey, discoKey, &disco.Ping{
TxID: [12]byte(txid), TxID: [12]byte(txid),
NodeKey: de.c.publicKeyAtomic.Load(), NodeKey: de.c.publicKeyAtomic.Load(),
Padding: padding, Padding: padding,

View File

@ -1603,16 +1603,43 @@ const (
// speeds. // speeds.
var debugIPv4DiscoPingPenalty = envknob.RegisterDuration("TS_DISCO_PONG_IPV4_DELAY") var debugIPv4DiscoPingPenalty = envknob.RegisterDuration("TS_DISCO_PONG_IPV4_DELAY")
// virtualNetworkID is a Geneve header (RFC8926) 3-byte virtual network
// identifier. Its field must only ever be accessed via its methods.
type virtualNetworkID struct {
_vni uint32
}
const (
vniSetMask uint32 = 0xFF000000
vniGetMask uint32 = ^vniSetMask
)
// isSet returns true if set() had been called previously, otherwise false.
func (v *virtualNetworkID) isSet() bool {
return v._vni&vniSetMask != 0
}
// set sets the provided VNI. If VNI exceeds the 3-byte storage it will be
// clamped.
func (v *virtualNetworkID) set(vni uint32) {
v._vni = vni | vniSetMask
}
// get returns the VNI value.
func (v *virtualNetworkID) get() uint32 {
return v._vni & vniGetMask
}
// sendDiscoMessage sends discovery message m to dstDisco at dst. // sendDiscoMessage sends discovery message m to dstDisco at dst.
// //
// If dst is a DERP IP:port, then dstKey must be non-zero. // If dst is a DERP IP:port, then dstKey must be non-zero.
// //
// If geneveVNI is non-nil, then the [disco.Message] will be preceded by a // If vni.isSet(), the [disco.Message] will be preceded by a Geneve header with
// Geneve header with the supplied VNI set. // the VNI field set to the value returned by vni.get().
// //
// The dstKey should only be non-zero if the dstDisco key // The dstKey should only be non-zero if the dstDisco key
// unambiguously maps to exactly one peer. // unambiguously maps to exactly one peer.
func (c *Conn) sendDiscoMessage(dst netip.AddrPort, geneveVNI *uint32, dstKey key.NodePublic, dstDisco key.DiscoPublic, m disco.Message, logLevel discoLogLevel) (sent bool, err error) { func (c *Conn) sendDiscoMessage(dst netip.AddrPort, vni virtualNetworkID, dstKey key.NodePublic, dstDisco key.DiscoPublic, m disco.Message, logLevel discoLogLevel) (sent bool, err error) {
isDERP := dst.Addr() == tailcfg.DerpMagicIPAddr isDERP := dst.Addr() == tailcfg.DerpMagicIPAddr
if _, isPong := m.(*disco.Pong); isPong && !isDERP && dst.Addr().Is4() { if _, isPong := m.(*disco.Pong); isPong && !isDERP && dst.Addr().Is4() {
time.Sleep(debugIPv4DiscoPingPenalty()) time.Sleep(debugIPv4DiscoPingPenalty())
@ -1651,11 +1678,11 @@ func (c *Conn) sendDiscoMessage(dst netip.AddrPort, geneveVNI *uint32, dstKey ke
c.mu.Unlock() c.mu.Unlock()
pkt := make([]byte, 0, 512) // TODO: size it correctly? pool? if it matters. pkt := make([]byte, 0, 512) // TODO: size it correctly? pool? if it matters.
if geneveVNI != nil { if vni.isSet() {
gh := packet.GeneveHeader{ gh := packet.GeneveHeader{
Version: 0, Version: 0,
Protocol: packet.GeneveProtocolDisco, Protocol: packet.GeneveProtocolDisco,
VNI: *geneveVNI, VNI: vni.get(),
Control: isRelayHandshakeMsg, Control: isRelayHandshakeMsg,
} }
pkt = append(pkt, make([]byte, packet.GeneveFixedHeaderLength)...) pkt = append(pkt, make([]byte, packet.GeneveFixedHeaderLength)...)
@ -1903,9 +1930,17 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc ke
switch dm := dm.(type) { switch dm := dm.(type) {
case *disco.Ping: case *disco.Ping:
metricRecvDiscoPing.Add(1) metricRecvDiscoPing.Add(1)
if isGeneveEncap {
// TODO(jwhited): handle Geneve-encapsulated disco ping.
return
}
c.handlePingLocked(dm, src, di, derpNodeSrc) c.handlePingLocked(dm, src, di, derpNodeSrc)
case *disco.Pong: case *disco.Pong:
metricRecvDiscoPong.Add(1) metricRecvDiscoPong.Add(1)
if isGeneveEncap {
// TODO(jwhited): handle Geneve-encapsulated disco pong.
return
}
// There might be multiple nodes for the sender's DiscoKey. // There might be multiple nodes for the sender's DiscoKey.
// Ask each to handle it, stopping once one reports that // Ask each to handle it, stopping once one reports that
// the Pong's TxID was theirs. // the Pong's TxID was theirs.
@ -2020,12 +2055,12 @@ func (c *Conn) handlePingLocked(dm *disco.Ping, src netip.AddrPort, di *discoInf
isDerp := src.Addr() == tailcfg.DerpMagicIPAddr isDerp := src.Addr() == tailcfg.DerpMagicIPAddr
// If we can figure out with certainty which node key this disco // If we can figure out with certainty which node key this disco
// message is for, eagerly update our IP<>node and disco<>node // message is for, eagerly update our IP:port<>node and disco<>node
// mappings to make p2p path discovery faster in simple // mappings to make p2p path discovery faster in simple
// cases. Without this, disco would still work, but would be // cases. Without this, disco would still work, but would be
// reliant on DERP call-me-maybe to establish the disco<>node // reliant on DERP call-me-maybe to establish the disco<>node
// mapping, and on subsequent disco handlePongConnLocked to establish // mapping, and on subsequent disco handlePongConnLocked to establish
// the IP<>disco mapping. // the IP:port<>disco mapping.
if nk, ok := c.unambiguousNodeKeyOfPingLocked(dm, di.discoKey, derpNodeSrc); ok { if nk, ok := c.unambiguousNodeKeyOfPingLocked(dm, di.discoKey, derpNodeSrc); ok {
if !isDerp { if !isDerp {
c.peerMap.setNodeKeyForIPPort(src, nk) c.peerMap.setNodeKeyForIPPort(src, nk)
@ -2086,7 +2121,7 @@ func (c *Conn) handlePingLocked(dm *disco.Ping, src netip.AddrPort, di *discoInf
ipDst := src ipDst := src
discoDest := di.discoKey discoDest := di.discoKey
go c.sendDiscoMessage(ipDst, nil, dstKey, discoDest, &disco.Pong{ go c.sendDiscoMessage(ipDst, virtualNetworkID{}, dstKey, discoDest, &disco.Pong{
TxID: dm.TxID, TxID: dm.TxID,
Src: src, Src: src,
}, discoVerboseLog) }, discoVerboseLog)
@ -2131,12 +2166,12 @@ func (c *Conn) enqueueCallMeMaybe(derpAddr netip.AddrPort, de *endpoint) {
for _, ep := range c.lastEndpoints { for _, ep := range c.lastEndpoints {
eps = append(eps, ep.Addr) eps = append(eps, ep.Addr)
} }
go de.c.sendDiscoMessage(derpAddr, nil, de.publicKey, epDisco.key, &disco.CallMeMaybe{MyNumber: eps}, discoLog) go de.c.sendDiscoMessage(derpAddr, virtualNetworkID{}, de.publicKey, epDisco.key, &disco.CallMeMaybe{MyNumber: eps}, discoLog)
if debugSendCallMeUnknownPeer() { if debugSendCallMeUnknownPeer() {
// Send a callMeMaybe packet to a non-existent peer // Send a callMeMaybe packet to a non-existent peer
unknownKey := key.NewNode().Public() unknownKey := key.NewNode().Public()
c.logf("magicsock: sending CallMeMaybe to unknown peer per TS_DEBUG_SEND_CALLME_UNKNOWN_PEER") c.logf("magicsock: sending CallMeMaybe to unknown peer per TS_DEBUG_SEND_CALLME_UNKNOWN_PEER")
go de.c.sendDiscoMessage(derpAddr, nil, unknownKey, epDisco.key, &disco.CallMeMaybe{MyNumber: eps}, discoLog) go de.c.sendDiscoMessage(derpAddr, virtualNetworkID{}, unknownKey, epDisco.key, &disco.CallMeMaybe{MyNumber: eps}, discoLog)
} }
} }

View File

@ -12,6 +12,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"math"
"math/rand" "math/rand"
"net" "net"
"net/http" "net/http"
@ -3317,3 +3318,52 @@ func Test_isDiscoMaybeGeneve(t *testing.T) {
}) })
} }
} }
func Test_virtualNetworkID(t *testing.T) {
tests := []struct {
name string
set *uint32
want uint32
}{
{
"don't set",
nil,
0,
},
{
"set 0",
ptr.To(uint32(0)),
0,
},
{
"set 1",
ptr.To(uint32(1)),
1,
},
{
"set math.MaxUint32",
ptr.To(uint32(math.MaxUint32)),
1<<24 - 1,
},
{
"set max 3-byte value",
ptr.To(uint32(1<<24 - 1)),
1<<24 - 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
v := virtualNetworkID{}
if tt.set != nil {
v.set(*tt.set)
}
if v.isSet() != (tt.set != nil) {
t.Fatalf("isSet: %v != wantIsSet: %v", v.isSet(), tt.set != nil)
}
if v.get() != tt.want {
t.Fatalf("get(): %v != want: %v", v.get(), tt.want)
}
})
}
}

View File

@ -16,7 +16,6 @@ import (
"tailscale.com/disco" "tailscale.com/disco"
udprelay "tailscale.com/net/udprelay/endpoint" udprelay "tailscale.com/net/udprelay/endpoint"
"tailscale.com/types/key" "tailscale.com/types/key"
"tailscale.com/types/ptr"
"tailscale.com/util/httpm" "tailscale.com/util/httpm"
"tailscale.com/util/set" "tailscale.com/util/set"
) )
@ -500,10 +499,12 @@ func (r *relayManager) handshakeServerEndpoint(work *relayHandshakeWork) {
sentBindAny := false sentBindAny := false
bind := &disco.BindUDPRelayEndpoint{} bind := &disco.BindUDPRelayEndpoint{}
vni := virtualNetworkID{}
vni.set(work.se.VNI)
for _, addrPort := range work.se.AddrPorts { for _, addrPort := range work.se.AddrPorts {
if addrPort.IsValid() { if addrPort.IsValid() {
sentBindAny = true sentBindAny = true
go work.ep.c.sendDiscoMessage(addrPort, ptr.To(work.se.VNI), key.NodePublic{}, work.se.ServerDisco, bind, discoLog) go work.ep.c.sendDiscoMessage(addrPort, vni, key.NodePublic{}, work.se.ServerDisco, bind, discoLog)
} }
} }
if !sentBindAny { if !sentBindAny {
@ -552,7 +553,7 @@ func (r *relayManager) handshakeServerEndpoint(work *relayHandshakeWork) {
// [udprelay.ServerEndpoint] from becoming fully operational. // [udprelay.ServerEndpoint] from becoming fully operational.
// 4. This is a singular tx with no roundtrip latency measurements // 4. This is a singular tx with no roundtrip latency measurements
// involved. // involved.
work.ep.c.sendDiscoMessage(challenge.from, ptr.To(work.se.VNI), key.NodePublic{}, work.se.ServerDisco, answer, discoLog) work.ep.c.sendDiscoMessage(challenge.from, vni, key.NodePublic{}, work.se.ServerDisco, answer, discoLog)
return return
case <-timer.C: case <-timer.C:
// The handshake timed out. // The handshake timed out.