wgengine/magicsock: add receiveIP() unit tests (#16781)

One of these tests highlighted a Geneve encap bug, which is also fixed
in this commit.

looksLikeInitMsg was passed a packet post Geneve header stripping with
slice offsets that had not been updated to account for the stripping.

Updates tailscale/corp#30903

Signed-off-by: Jordan Whited <jordan@tailscale.com>
This commit is contained in:
Jordan Whited
2025-08-06 09:35:25 -07:00
committed by GitHub
parent 57d653014b
commit 908f20e0a5
2 changed files with 319 additions and 1 deletions

View File

@@ -1823,6 +1823,9 @@ func (c *Conn) receiveIP(b []byte, ipp netip.AddrPort, cache *epAddrEndpointCach
return nil, 0, false, false return nil, 0, false, false
} }
// geneveInclusivePacketLen holds the packet length prior to any potential
// Geneve header stripping.
geneveInclusivePacketLen := len(b)
if src.vni.isSet() { if src.vni.isSet() {
// Strip away the Geneve header before returning the packet to // Strip away the Geneve header before returning the packet to
// wireguard-go. // wireguard-go.
@@ -1831,6 +1834,7 @@ func (c *Conn) receiveIP(b []byte, ipp netip.AddrPort, cache *epAddrEndpointCach
// to support returning start offset in order to get rid of this memmove perf // to support returning start offset in order to get rid of this memmove perf
// penalty. // penalty.
size = copy(b, b[packet.GeneveFixedHeaderLength:]) size = copy(b, b[packet.GeneveFixedHeaderLength:])
b = b[:size]
} }
if cache.epAddr == src && cache.de != nil && cache.gen == cache.de.numStopAndReset() { if cache.epAddr == src && cache.de != nil && cache.gen == cache.de.numStopAndReset() {
@@ -1859,7 +1863,7 @@ func (c *Conn) receiveIP(b []byte, ipp netip.AddrPort, cache *epAddrEndpointCach
ep.lastRecvUDPAny.StoreAtomic(now) ep.lastRecvUDPAny.StoreAtomic(now)
connNoted := ep.noteRecvActivity(src, now) connNoted := ep.noteRecvActivity(src, now)
if stats := c.stats.Load(); stats != nil { if stats := c.stats.Load(); stats != nil {
stats.UpdateRxPhysical(ep.nodeAddr, ipp, 1, len(b)) stats.UpdateRxPhysical(ep.nodeAddr, ipp, 1, geneveInclusivePacketLen)
} }
if src.vni.isSet() && (connNoted || looksLikeInitiationMsg(b)) { if src.vni.isSet() && (connNoted || looksLikeInitiationMsg(b)) {
// connNoted is periodic, but we also want to verify if the peer is who // connNoted is periodic, but we also want to verify if the peer is who

View File

@@ -20,6 +20,7 @@ import (
"net/http/httptest" "net/http/httptest"
"net/netip" "net/netip"
"os" "os"
"reflect"
"runtime" "runtime"
"strconv" "strconv"
"strings" "strings"
@@ -66,6 +67,7 @@ import (
"tailscale.com/types/ptr" "tailscale.com/types/ptr"
"tailscale.com/types/views" "tailscale.com/types/views"
"tailscale.com/util/cibuild" "tailscale.com/util/cibuild"
"tailscale.com/util/clientmetric"
"tailscale.com/util/eventbus" "tailscale.com/util/eventbus"
"tailscale.com/util/must" "tailscale.com/util/must"
"tailscale.com/util/racebuild" "tailscale.com/util/racebuild"
@@ -3701,3 +3703,315 @@ func TestConn_updateRelayServersSet(t *testing.T) {
}) })
} }
} }
func TestConn_receiveIP(t *testing.T) {
looksLikeNakedDisco := make([]byte, 0, len(disco.Magic)+key.DiscoPublicRawLen)
looksLikeNakedDisco = append(looksLikeNakedDisco, disco.Magic...)
looksLikeNakedDisco = looksLikeNakedDisco[:cap(looksLikeNakedDisco)]
looksLikeGeneveDisco := make([]byte, packet.GeneveFixedHeaderLength+len(looksLikeNakedDisco))
gh := packet.GeneveHeader{
Protocol: packet.GeneveProtocolDisco,
}
err := gh.Encode(looksLikeGeneveDisco)
if err != nil {
t.Fatal(err)
}
copy(looksLikeGeneveDisco[packet.GeneveFixedHeaderLength:], looksLikeNakedDisco)
looksLikeSTUNBinding := stun.Response(stun.NewTxID(), netip.MustParseAddrPort("127.0.0.1:7777"))
findMetricByName := func(name string) *clientmetric.Metric {
for _, metric := range clientmetric.Metrics() {
if metric.Name() == name {
return metric
}
}
t.Fatalf("failed to find metric with name: %v", name)
return nil
}
looksLikeNakedWireGuardInit := make([]byte, device.MessageInitiationSize)
binary.LittleEndian.PutUint32(looksLikeNakedWireGuardInit, device.MessageInitiationType)
looksLikeGeneveWireGuardInit := make([]byte, packet.GeneveFixedHeaderLength+device.MessageInitiationSize)
gh = packet.GeneveHeader{
Protocol: packet.GeneveProtocolWireGuard,
VNI: 1,
}
vni := virtualNetworkID{}
vni.set(gh.VNI)
err = gh.Encode(looksLikeGeneveWireGuardInit)
if err != nil {
t.Fatal(err)
}
copy(looksLikeGeneveWireGuardInit[packet.GeneveFixedHeaderLength:], looksLikeNakedWireGuardInit)
newPeerMapInsertableEndpoint := func(lastRecvWG mono.Time) *endpoint {
ep := &endpoint{
nodeID: 1,
publicKey: key.NewNode().Public(),
lastRecvWG: lastRecvWG,
}
ep.disco.Store(&endpointDisco{
key: key.NewDisco().Public(),
})
return ep
}
tests := []struct {
name string
// A copy of b is used as input, tests may re-use the same value.
b []byte
ipp netip.AddrPort
// cache must be non-nil, and must not be reused across tests. If
// cache.de is non-nil after receiveIP(), then we verify it is equal to
// wantEndpointType.
cache *epAddrEndpointCache
// If true, wantEndpointType is inserted into the [peerMap].
insertWantEndpointTypeInPeerMap bool
// If insertWantEndpointTypeInPeerMap is true, use this [epAddr] for it
// in the [peerMap.setNodeKeyForEpAddr] call.
peerMapEpAddr epAddr
// If [*endpoint] then we expect 'got' to be the same [*endpoint]. If
// [*lazyEndpoint] and [*lazyEndpoint.maybeEP] is non-nil, we expect
// got.maybeEP to also be non-nil. Must not be reused across tests.
wantEndpointType wgconn.Endpoint
wantSize int
wantIsGeneveEncap bool
wantOk bool
wantMetricInc *clientmetric.Metric
wantNoteRecvActivityCalled bool
}{
{
name: "naked disco",
b: looksLikeNakedDisco,
ipp: netip.MustParseAddrPort("127.0.0.1:7777"),
cache: &epAddrEndpointCache{},
wantEndpointType: nil,
wantSize: 0,
wantIsGeneveEncap: false,
wantOk: false,
wantMetricInc: metricRecvDiscoBadPeer,
wantNoteRecvActivityCalled: false,
},
{
name: "geneve encap disco",
b: looksLikeGeneveDisco,
ipp: netip.MustParseAddrPort("127.0.0.1:7777"),
cache: &epAddrEndpointCache{},
wantEndpointType: nil,
wantSize: 0,
wantIsGeneveEncap: false,
wantOk: false,
wantMetricInc: metricRecvDiscoBadPeer,
wantNoteRecvActivityCalled: false,
},
{
name: "STUN binding",
b: looksLikeSTUNBinding,
ipp: netip.MustParseAddrPort("127.0.0.1:7777"),
cache: &epAddrEndpointCache{},
wantEndpointType: nil,
wantSize: 0,
wantIsGeneveEncap: false,
wantOk: false,
wantMetricInc: findMetricByName("netcheck_stun_recv_ipv4"),
wantNoteRecvActivityCalled: false,
},
{
name: "naked WireGuard init lazyEndpoint empty peerMap",
b: looksLikeNakedWireGuardInit,
ipp: netip.MustParseAddrPort("127.0.0.1:7777"),
cache: &epAddrEndpointCache{},
wantEndpointType: &lazyEndpoint{},
wantSize: len(looksLikeNakedWireGuardInit),
wantIsGeneveEncap: false,
wantOk: true,
wantMetricInc: nil,
wantNoteRecvActivityCalled: false,
},
{
name: "naked WireGuard init endpoint matching peerMap entry",
b: looksLikeNakedWireGuardInit,
ipp: netip.MustParseAddrPort("127.0.0.1:7777"),
cache: &epAddrEndpointCache{},
insertWantEndpointTypeInPeerMap: true,
peerMapEpAddr: epAddr{ap: netip.MustParseAddrPort("127.0.0.1:7777")},
wantEndpointType: newPeerMapInsertableEndpoint(0),
wantSize: len(looksLikeNakedWireGuardInit),
wantIsGeneveEncap: false,
wantOk: true,
wantMetricInc: nil,
wantNoteRecvActivityCalled: true,
},
{
name: "geneve WireGuard init lazyEndpoint empty peerMap",
b: looksLikeGeneveWireGuardInit,
ipp: netip.MustParseAddrPort("127.0.0.1:7777"),
cache: &epAddrEndpointCache{},
wantEndpointType: &lazyEndpoint{},
wantSize: len(looksLikeGeneveWireGuardInit) - packet.GeneveFixedHeaderLength,
wantIsGeneveEncap: true,
wantOk: true,
wantMetricInc: nil,
wantNoteRecvActivityCalled: false,
},
{
name: "geneve WireGuard init lazyEndpoint matching peerMap activity noted",
b: looksLikeGeneveWireGuardInit,
ipp: netip.MustParseAddrPort("127.0.0.1:7777"),
cache: &epAddrEndpointCache{},
insertWantEndpointTypeInPeerMap: true,
peerMapEpAddr: epAddr{ap: netip.MustParseAddrPort("127.0.0.1:7777"), vni: vni},
wantEndpointType: &lazyEndpoint{
maybeEP: newPeerMapInsertableEndpoint(0),
},
wantSize: len(looksLikeGeneveWireGuardInit) - packet.GeneveFixedHeaderLength,
wantIsGeneveEncap: true,
wantOk: true,
wantMetricInc: nil,
wantNoteRecvActivityCalled: true,
},
{
name: "geneve WireGuard init lazyEndpoint matching peerMap no activity noted",
b: looksLikeGeneveWireGuardInit,
ipp: netip.MustParseAddrPort("127.0.0.1:7777"),
cache: &epAddrEndpointCache{},
insertWantEndpointTypeInPeerMap: true,
peerMapEpAddr: epAddr{ap: netip.MustParseAddrPort("127.0.0.1:7777"), vni: vni},
wantEndpointType: &lazyEndpoint{
maybeEP: newPeerMapInsertableEndpoint(mono.Now().Add(time.Hour * 24)),
},
wantSize: len(looksLikeGeneveWireGuardInit) - packet.GeneveFixedHeaderLength,
wantIsGeneveEncap: true,
wantOk: true,
wantMetricInc: nil,
wantNoteRecvActivityCalled: false,
},
// TODO(jwhited): verify cache.de is used when conditions permit
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
noteRecvActivityCalled := false
metricBefore := int64(0)
if tt.wantMetricInc != nil {
metricBefore = tt.wantMetricInc.Value()
}
// Init Conn.
c := &Conn{
privateKey: key.NewNode(),
netChecker: &netcheck.Client{},
peerMap: newPeerMap(),
}
c.havePrivateKey.Store(true)
c.noteRecvActivity = func(public key.NodePublic) {
noteRecvActivityCalled = true
}
c.SetStatistics(connstats.NewStatistics(0, 0, nil))
if tt.insertWantEndpointTypeInPeerMap {
var insertEPIntoPeerMap *endpoint
switch ep := tt.wantEndpointType.(type) {
case *endpoint:
insertEPIntoPeerMap = ep
case *lazyEndpoint:
insertEPIntoPeerMap = ep.maybeEP
default:
t.Fatal("unexpected tt.wantEndpointType concrete type")
}
insertEPIntoPeerMap.c = c
c.peerMap.upsertEndpoint(insertEPIntoPeerMap, key.DiscoPublic{})
c.peerMap.setNodeKeyForEpAddr(tt.peerMapEpAddr, insertEPIntoPeerMap.publicKey)
}
// Allow the same input packet to be used across tests, receiveIP()
// may mutate.
inputPacket := make([]byte, len(tt.b))
copy(inputPacket, tt.b)
got, gotSize, gotIsGeneveEncap, gotOk := c.receiveIP(inputPacket, tt.ipp, tt.cache)
if (tt.wantEndpointType == nil) != (got == nil) {
t.Errorf("receiveIP() (tt.wantEndpointType == nil): %v != (got == nil): %v", tt.wantEndpointType == nil, got == nil)
}
if tt.wantEndpointType != nil && reflect.TypeOf(got).String() != reflect.TypeOf(tt.wantEndpointType).String() {
t.Errorf("receiveIP() got = %v, want %v", reflect.TypeOf(got).String(), reflect.TypeOf(tt.wantEndpointType).String())
} else {
switch ep := tt.wantEndpointType.(type) {
case *endpoint:
if ep != got.(*endpoint) {
t.Errorf("receiveIP() want [*endpoint]: %p != got [*endpoint]: %p", ep, got)
}
case *lazyEndpoint:
if ep.maybeEP != nil && ep.maybeEP != got.(*lazyEndpoint).maybeEP {
t.Errorf("receiveIP() want [*lazyEndpoint.maybeEP]: %p != got [*lazyEndpoint.maybeEP] %p", ep, got)
}
}
}
if gotSize != tt.wantSize {
t.Errorf("receiveIP() gotSize = %v, want %v", gotSize, tt.wantSize)
}
if gotIsGeneveEncap != tt.wantIsGeneveEncap {
t.Errorf("receiveIP() gotIsGeneveEncap = %v, want %v", gotIsGeneveEncap, tt.wantIsGeneveEncap)
}
if gotOk != tt.wantOk {
t.Errorf("receiveIP() gotOk = %v, want %v", gotOk, tt.wantOk)
}
if tt.wantMetricInc != nil && tt.wantMetricInc.Value() != metricBefore+1 {
t.Errorf("receiveIP() metric %v not incremented", tt.wantMetricInc.Name())
}
if tt.wantNoteRecvActivityCalled != noteRecvActivityCalled {
t.Errorf("receiveIP() noteRecvActivityCalled = %v, want %v", noteRecvActivityCalled, tt.wantNoteRecvActivityCalled)
}
if tt.cache.de != nil {
switch ep := got.(type) {
case *endpoint:
if tt.cache.de != ep {
t.Errorf("receiveIP() cache populated with [*endpoint] %p, want %p", tt.cache.de, ep)
}
case *lazyEndpoint:
if tt.cache.de != ep.maybeEP {
t.Errorf("receiveIP() cache populated with [*endpoint] %p, want (lazyEndpoint.maybeEP) %p", tt.cache.de, ep.maybeEP)
}
default:
t.Fatal("receiveIP() unexpected [conn.Endpoint] type")
}
}
// Verify physical rx stats
stats := c.stats.Load()
_, gotPhy := stats.TestExtract()
wantNonzeroRxStats := false
switch ep := tt.wantEndpointType.(type) {
case *lazyEndpoint:
if ep.maybeEP != nil {
wantNonzeroRxStats = true
}
case *endpoint:
wantNonzeroRxStats = true
}
if tt.wantOk && wantNonzeroRxStats {
wantRxBytes := uint64(tt.wantSize)
if tt.wantIsGeneveEncap {
wantRxBytes += packet.GeneveFixedHeaderLength
}
wantPhy := map[netlogtype.Connection]netlogtype.Counts{
{Dst: tt.ipp}: {
RxPackets: 1,
RxBytes: wantRxBytes,
},
}
if !reflect.DeepEqual(gotPhy, wantPhy) {
t.Errorf("receiveIP() got physical conn stats = %v, want %v", gotPhy, wantPhy)
}
} else {
if len(gotPhy) != 0 {
t.Errorf("receiveIP() unexpected nonzero physical count stats: %+v", gotPhy)
}
}
})
}
}