From 908f20e0a506f9fe0c3f6479bc6b7c017cab27a1 Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Wed, 6 Aug 2025 09:35:25 -0700 Subject: [PATCH] wgengine/magicsock: add receiveIP() unit tests (#16781) One of these tests highlighted a Geneve encap bug, which is also fixed in this commit. looksLikeInitMsg was passed a packet post Geneve header stripping with slice offsets that had not been updated to account for the stripping. Updates tailscale/corp#30903 Signed-off-by: Jordan Whited --- wgengine/magicsock/magicsock.go | 6 +- wgengine/magicsock/magicsock_test.go | 314 +++++++++++++++++++++++++++ 2 files changed, 319 insertions(+), 1 deletion(-) diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index c99d1b68f..04d4bbbde 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -1823,6 +1823,9 @@ func (c *Conn) receiveIP(b []byte, ipp netip.AddrPort, cache *epAddrEndpointCach return nil, 0, false, false } + // geneveInclusivePacketLen holds the packet length prior to any potential + // Geneve header stripping. + geneveInclusivePacketLen := len(b) if src.vni.isSet() { // Strip away the Geneve header before returning the packet to // wireguard-go. @@ -1831,6 +1834,7 @@ func (c *Conn) receiveIP(b []byte, ipp netip.AddrPort, cache *epAddrEndpointCach // to support returning start offset in order to get rid of this memmove perf // penalty. size = copy(b, b[packet.GeneveFixedHeaderLength:]) + b = b[:size] } if cache.epAddr == src && cache.de != nil && cache.gen == cache.de.numStopAndReset() { @@ -1859,7 +1863,7 @@ func (c *Conn) receiveIP(b []byte, ipp netip.AddrPort, cache *epAddrEndpointCach ep.lastRecvUDPAny.StoreAtomic(now) connNoted := ep.noteRecvActivity(src, now) if stats := c.stats.Load(); stats != nil { - stats.UpdateRxPhysical(ep.nodeAddr, ipp, 1, len(b)) + stats.UpdateRxPhysical(ep.nodeAddr, ipp, 1, geneveInclusivePacketLen) } if src.vni.isSet() && (connNoted || looksLikeInitiationMsg(b)) { // connNoted is periodic, but we also want to verify if the peer is who diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index 0d1ac9dfd..685fff4da 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -20,6 +20,7 @@ import ( "net/http/httptest" "net/netip" "os" + "reflect" "runtime" "strconv" "strings" @@ -66,6 +67,7 @@ import ( "tailscale.com/types/ptr" "tailscale.com/types/views" "tailscale.com/util/cibuild" + "tailscale.com/util/clientmetric" "tailscale.com/util/eventbus" "tailscale.com/util/must" "tailscale.com/util/racebuild" @@ -3701,3 +3703,315 @@ func TestConn_updateRelayServersSet(t *testing.T) { }) } } + +func TestConn_receiveIP(t *testing.T) { + looksLikeNakedDisco := make([]byte, 0, len(disco.Magic)+key.DiscoPublicRawLen) + looksLikeNakedDisco = append(looksLikeNakedDisco, disco.Magic...) + looksLikeNakedDisco = looksLikeNakedDisco[:cap(looksLikeNakedDisco)] + + looksLikeGeneveDisco := make([]byte, packet.GeneveFixedHeaderLength+len(looksLikeNakedDisco)) + gh := packet.GeneveHeader{ + Protocol: packet.GeneveProtocolDisco, + } + err := gh.Encode(looksLikeGeneveDisco) + if err != nil { + t.Fatal(err) + } + copy(looksLikeGeneveDisco[packet.GeneveFixedHeaderLength:], looksLikeNakedDisco) + + looksLikeSTUNBinding := stun.Response(stun.NewTxID(), netip.MustParseAddrPort("127.0.0.1:7777")) + + findMetricByName := func(name string) *clientmetric.Metric { + for _, metric := range clientmetric.Metrics() { + if metric.Name() == name { + return metric + } + } + t.Fatalf("failed to find metric with name: %v", name) + return nil + } + + looksLikeNakedWireGuardInit := make([]byte, device.MessageInitiationSize) + binary.LittleEndian.PutUint32(looksLikeNakedWireGuardInit, device.MessageInitiationType) + + looksLikeGeneveWireGuardInit := make([]byte, packet.GeneveFixedHeaderLength+device.MessageInitiationSize) + gh = packet.GeneveHeader{ + Protocol: packet.GeneveProtocolWireGuard, + VNI: 1, + } + vni := virtualNetworkID{} + vni.set(gh.VNI) + err = gh.Encode(looksLikeGeneveWireGuardInit) + if err != nil { + t.Fatal(err) + } + copy(looksLikeGeneveWireGuardInit[packet.GeneveFixedHeaderLength:], looksLikeNakedWireGuardInit) + + newPeerMapInsertableEndpoint := func(lastRecvWG mono.Time) *endpoint { + ep := &endpoint{ + nodeID: 1, + publicKey: key.NewNode().Public(), + lastRecvWG: lastRecvWG, + } + ep.disco.Store(&endpointDisco{ + key: key.NewDisco().Public(), + }) + return ep + } + + tests := []struct { + name string + // A copy of b is used as input, tests may re-use the same value. + b []byte + ipp netip.AddrPort + // cache must be non-nil, and must not be reused across tests. If + // cache.de is non-nil after receiveIP(), then we verify it is equal to + // wantEndpointType. + cache *epAddrEndpointCache + // If true, wantEndpointType is inserted into the [peerMap]. + insertWantEndpointTypeInPeerMap bool + // If insertWantEndpointTypeInPeerMap is true, use this [epAddr] for it + // in the [peerMap.setNodeKeyForEpAddr] call. + peerMapEpAddr epAddr + // If [*endpoint] then we expect 'got' to be the same [*endpoint]. If + // [*lazyEndpoint] and [*lazyEndpoint.maybeEP] is non-nil, we expect + // got.maybeEP to also be non-nil. Must not be reused across tests. + wantEndpointType wgconn.Endpoint + wantSize int + wantIsGeneveEncap bool + wantOk bool + wantMetricInc *clientmetric.Metric + wantNoteRecvActivityCalled bool + }{ + { + name: "naked disco", + b: looksLikeNakedDisco, + ipp: netip.MustParseAddrPort("127.0.0.1:7777"), + cache: &epAddrEndpointCache{}, + wantEndpointType: nil, + wantSize: 0, + wantIsGeneveEncap: false, + wantOk: false, + wantMetricInc: metricRecvDiscoBadPeer, + wantNoteRecvActivityCalled: false, + }, + { + name: "geneve encap disco", + b: looksLikeGeneveDisco, + ipp: netip.MustParseAddrPort("127.0.0.1:7777"), + cache: &epAddrEndpointCache{}, + wantEndpointType: nil, + wantSize: 0, + wantIsGeneveEncap: false, + wantOk: false, + wantMetricInc: metricRecvDiscoBadPeer, + wantNoteRecvActivityCalled: false, + }, + { + name: "STUN binding", + b: looksLikeSTUNBinding, + ipp: netip.MustParseAddrPort("127.0.0.1:7777"), + cache: &epAddrEndpointCache{}, + wantEndpointType: nil, + wantSize: 0, + wantIsGeneveEncap: false, + wantOk: false, + wantMetricInc: findMetricByName("netcheck_stun_recv_ipv4"), + wantNoteRecvActivityCalled: false, + }, + { + name: "naked WireGuard init lazyEndpoint empty peerMap", + b: looksLikeNakedWireGuardInit, + ipp: netip.MustParseAddrPort("127.0.0.1:7777"), + cache: &epAddrEndpointCache{}, + wantEndpointType: &lazyEndpoint{}, + wantSize: len(looksLikeNakedWireGuardInit), + wantIsGeneveEncap: false, + wantOk: true, + wantMetricInc: nil, + wantNoteRecvActivityCalled: false, + }, + { + name: "naked WireGuard init endpoint matching peerMap entry", + b: looksLikeNakedWireGuardInit, + ipp: netip.MustParseAddrPort("127.0.0.1:7777"), + cache: &epAddrEndpointCache{}, + insertWantEndpointTypeInPeerMap: true, + peerMapEpAddr: epAddr{ap: netip.MustParseAddrPort("127.0.0.1:7777")}, + wantEndpointType: newPeerMapInsertableEndpoint(0), + wantSize: len(looksLikeNakedWireGuardInit), + wantIsGeneveEncap: false, + wantOk: true, + wantMetricInc: nil, + wantNoteRecvActivityCalled: true, + }, + { + name: "geneve WireGuard init lazyEndpoint empty peerMap", + b: looksLikeGeneveWireGuardInit, + ipp: netip.MustParseAddrPort("127.0.0.1:7777"), + cache: &epAddrEndpointCache{}, + wantEndpointType: &lazyEndpoint{}, + wantSize: len(looksLikeGeneveWireGuardInit) - packet.GeneveFixedHeaderLength, + wantIsGeneveEncap: true, + wantOk: true, + wantMetricInc: nil, + wantNoteRecvActivityCalled: false, + }, + { + name: "geneve WireGuard init lazyEndpoint matching peerMap activity noted", + b: looksLikeGeneveWireGuardInit, + ipp: netip.MustParseAddrPort("127.0.0.1:7777"), + cache: &epAddrEndpointCache{}, + insertWantEndpointTypeInPeerMap: true, + peerMapEpAddr: epAddr{ap: netip.MustParseAddrPort("127.0.0.1:7777"), vni: vni}, + wantEndpointType: &lazyEndpoint{ + maybeEP: newPeerMapInsertableEndpoint(0), + }, + wantSize: len(looksLikeGeneveWireGuardInit) - packet.GeneveFixedHeaderLength, + wantIsGeneveEncap: true, + wantOk: true, + wantMetricInc: nil, + wantNoteRecvActivityCalled: true, + }, + { + name: "geneve WireGuard init lazyEndpoint matching peerMap no activity noted", + b: looksLikeGeneveWireGuardInit, + ipp: netip.MustParseAddrPort("127.0.0.1:7777"), + cache: &epAddrEndpointCache{}, + insertWantEndpointTypeInPeerMap: true, + peerMapEpAddr: epAddr{ap: netip.MustParseAddrPort("127.0.0.1:7777"), vni: vni}, + wantEndpointType: &lazyEndpoint{ + maybeEP: newPeerMapInsertableEndpoint(mono.Now().Add(time.Hour * 24)), + }, + wantSize: len(looksLikeGeneveWireGuardInit) - packet.GeneveFixedHeaderLength, + wantIsGeneveEncap: true, + wantOk: true, + wantMetricInc: nil, + wantNoteRecvActivityCalled: false, + }, + // TODO(jwhited): verify cache.de is used when conditions permit + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + noteRecvActivityCalled := false + metricBefore := int64(0) + if tt.wantMetricInc != nil { + metricBefore = tt.wantMetricInc.Value() + } + + // Init Conn. + c := &Conn{ + privateKey: key.NewNode(), + netChecker: &netcheck.Client{}, + peerMap: newPeerMap(), + } + c.havePrivateKey.Store(true) + c.noteRecvActivity = func(public key.NodePublic) { + noteRecvActivityCalled = true + } + c.SetStatistics(connstats.NewStatistics(0, 0, nil)) + + if tt.insertWantEndpointTypeInPeerMap { + var insertEPIntoPeerMap *endpoint + switch ep := tt.wantEndpointType.(type) { + case *endpoint: + insertEPIntoPeerMap = ep + case *lazyEndpoint: + insertEPIntoPeerMap = ep.maybeEP + default: + t.Fatal("unexpected tt.wantEndpointType concrete type") + } + insertEPIntoPeerMap.c = c + c.peerMap.upsertEndpoint(insertEPIntoPeerMap, key.DiscoPublic{}) + c.peerMap.setNodeKeyForEpAddr(tt.peerMapEpAddr, insertEPIntoPeerMap.publicKey) + } + + // Allow the same input packet to be used across tests, receiveIP() + // may mutate. + inputPacket := make([]byte, len(tt.b)) + copy(inputPacket, tt.b) + + got, gotSize, gotIsGeneveEncap, gotOk := c.receiveIP(inputPacket, tt.ipp, tt.cache) + if (tt.wantEndpointType == nil) != (got == nil) { + t.Errorf("receiveIP() (tt.wantEndpointType == nil): %v != (got == nil): %v", tt.wantEndpointType == nil, got == nil) + } + if tt.wantEndpointType != nil && reflect.TypeOf(got).String() != reflect.TypeOf(tt.wantEndpointType).String() { + t.Errorf("receiveIP() got = %v, want %v", reflect.TypeOf(got).String(), reflect.TypeOf(tt.wantEndpointType).String()) + } else { + switch ep := tt.wantEndpointType.(type) { + case *endpoint: + if ep != got.(*endpoint) { + t.Errorf("receiveIP() want [*endpoint]: %p != got [*endpoint]: %p", ep, got) + } + case *lazyEndpoint: + if ep.maybeEP != nil && ep.maybeEP != got.(*lazyEndpoint).maybeEP { + t.Errorf("receiveIP() want [*lazyEndpoint.maybeEP]: %p != got [*lazyEndpoint.maybeEP] %p", ep, got) + } + } + } + + if gotSize != tt.wantSize { + t.Errorf("receiveIP() gotSize = %v, want %v", gotSize, tt.wantSize) + } + if gotIsGeneveEncap != tt.wantIsGeneveEncap { + t.Errorf("receiveIP() gotIsGeneveEncap = %v, want %v", gotIsGeneveEncap, tt.wantIsGeneveEncap) + } + if gotOk != tt.wantOk { + t.Errorf("receiveIP() gotOk = %v, want %v", gotOk, tt.wantOk) + } + if tt.wantMetricInc != nil && tt.wantMetricInc.Value() != metricBefore+1 { + t.Errorf("receiveIP() metric %v not incremented", tt.wantMetricInc.Name()) + } + if tt.wantNoteRecvActivityCalled != noteRecvActivityCalled { + t.Errorf("receiveIP() noteRecvActivityCalled = %v, want %v", noteRecvActivityCalled, tt.wantNoteRecvActivityCalled) + } + + if tt.cache.de != nil { + switch ep := got.(type) { + case *endpoint: + if tt.cache.de != ep { + t.Errorf("receiveIP() cache populated with [*endpoint] %p, want %p", tt.cache.de, ep) + } + case *lazyEndpoint: + if tt.cache.de != ep.maybeEP { + t.Errorf("receiveIP() cache populated with [*endpoint] %p, want (lazyEndpoint.maybeEP) %p", tt.cache.de, ep.maybeEP) + } + default: + t.Fatal("receiveIP() unexpected [conn.Endpoint] type") + } + } + + // Verify physical rx stats + stats := c.stats.Load() + _, gotPhy := stats.TestExtract() + wantNonzeroRxStats := false + switch ep := tt.wantEndpointType.(type) { + case *lazyEndpoint: + if ep.maybeEP != nil { + wantNonzeroRxStats = true + } + case *endpoint: + wantNonzeroRxStats = true + } + if tt.wantOk && wantNonzeroRxStats { + wantRxBytes := uint64(tt.wantSize) + if tt.wantIsGeneveEncap { + wantRxBytes += packet.GeneveFixedHeaderLength + } + wantPhy := map[netlogtype.Connection]netlogtype.Counts{ + {Dst: tt.ipp}: { + RxPackets: 1, + RxBytes: wantRxBytes, + }, + } + if !reflect.DeepEqual(gotPhy, wantPhy) { + t.Errorf("receiveIP() got physical conn stats = %v, want %v", gotPhy, wantPhy) + } + } else { + if len(gotPhy) != 0 { + t.Errorf("receiveIP() unexpected nonzero physical count stats: %+v", gotPhy) + } + } + }) + } +}