From 080387558c6d7654ac6d7a694edc73c32b10b2cb Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Wed, 30 Apr 2025 13:31:35 -0700 Subject: [PATCH] wgengine/magicsock: start to make disco reception Geneve aware (#15832) Updates tailscale/corp#27502 Signed-off-by: Jordan Whited --- wgengine/magicsock/magicsock.go | 64 +++++++++-- wgengine/magicsock/magicsock_test.go | 162 +++++++++++++++++++++++++++ 2 files changed, 215 insertions(+), 11 deletions(-) diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 31bf66b2b..28ad06d2a 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -9,6 +9,7 @@ import ( "bufio" "bytes" "context" + "encoding/binary" "errors" "expvar" "fmt" @@ -1707,6 +1708,45 @@ const ( discoRXPathRawSocket discoRXPath = "raw socket" ) +const discoHeaderLen = len(disco.Magic) + key.DiscoPublicRawLen + +// isDiscoMaybeGeneve reports whether msg is a Tailscale Disco protocol +// message, and if true, whether it is encapsulated by a Geneve header. +// +// isGeneveEncap is only relevant when isDiscoMsg is true. +// +// Naked Disco, Geneve followed by Disco, and naked WireGuard can be confidently +// distinguished based on the following: +// 1. [disco.Magic] is sufficiently non-overlapping with a Geneve protocol +// field value of [packet.GeneveProtocolDisco]. +// 2. [disco.Magic] is sufficiently non-overlapping with the first 4 bytes of +// a WireGuard packet. +// 3. [packet.GeneveHeader] with a Geneve protocol field value of +// [packet.GeneveProtocolDisco] is sufficiently non-overlapping with the +// first 4 bytes of a WireGuard packet. +func isDiscoMaybeGeneve(msg []byte) (isDiscoMsg bool, isGeneveEncap bool) { + if len(msg) < discoHeaderLen { + return false, false + } + if string(msg[:len(disco.Magic)]) == disco.Magic { + return true, false + } + if len(msg) < packet.GeneveFixedHeaderLength+discoHeaderLen { + return false, false + } + if msg[0]&0xC0 != 0 || // version bits that we always transmit as 0s + msg[1]&0x3F != 0 || // reserved bits that we always transmit as 0s + binary.BigEndian.Uint16(msg[2:4]) != packet.GeneveProtocolDisco || + msg[7] != 0 { // reserved byte that we always transmit as 0 + return false, false + } + msg = msg[packet.GeneveFixedHeaderLength:] + if string(msg[:len(disco.Magic)]) == disco.Magic { + return true, true + } + return false, false +} + // handleDiscoMessage handles a discovery message and reports whether // msg was a Tailscale inter-node discovery message. // @@ -1722,18 +1762,16 @@ const ( // it was received from at the DERP layer. derpNodeSrc is zero when received // over UDP. func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc key.NodePublic, via discoRXPath) (isDiscoMsg bool) { - const headerLen = len(disco.Magic) + key.DiscoPublicRawLen - if len(msg) < headerLen || string(msg[:len(disco.Magic)]) != disco.Magic { - return false + isDiscoMsg, isGeneveEncap := isDiscoMaybeGeneve(msg) + if !isDiscoMsg { + return + } + if isGeneveEncap { + // TODO(jwhited): decode Geneve header + msg = msg[packet.GeneveFixedHeaderLength:] } - // If the first four parts are the prefix of disco.Magic - // (0x5453f09f) then it's definitely not a valid WireGuard - // packet (which starts with little-endian uint32 1, 2, 3, 4). - // Use naked returns for all following paths. - isDiscoMsg = true - - sender := key.DiscoPublicFromRaw32(mem.B(msg[len(disco.Magic):headerLen])) + sender := key.DiscoPublicFromRaw32(mem.B(msg[len(disco.Magic):discoHeaderLen])) c.mu.Lock() defer c.mu.Unlock() @@ -1751,6 +1789,10 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc ke } if !c.peerMap.knownPeerDiscoKey(sender) { + // Geneve encapsulated disco used for udp relay handshakes are not known + // "peer" keys as they are dynamically discovered by UDP relay endpoint + // allocation or [disco.CallMeMaybeVia] reception. + // TODO(jwhited): handle relay handshake messsages instead of early return metricRecvDiscoBadPeer.Add(1) if debugDisco() { c.logf("magicsock: disco: ignoring disco-looking frame, don't know of key %v", sender.ShortString()) @@ -1774,7 +1816,7 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc ke di := c.discoInfoLocked(sender) - sealedBox := msg[headerLen:] + sealedBox := msg[discoHeaderLen:] payload, ok := di.sharedKey.Open(sealedBox) if !ok { // This might be have been intended for a previous diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index f50f21f56..1a899ea22 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -3155,3 +3155,165 @@ func TestNetworkDownSendErrors(t *testing.T) { t.Errorf("expected NetworkDown to increment packet dropped metric; got %q", resp.Body.String()) } } + +func Test_isDiscoMaybeGeneve(t *testing.T) { + discoPub := key.DiscoPublicFromRaw32(mem.B([]byte{1: 1, 30: 30, 31: 31})) + nakedDisco := make([]byte, 0, 512) + nakedDisco = append(nakedDisco, disco.Magic...) + nakedDisco = discoPub.AppendTo(nakedDisco) + + geneveEncapDisco := make([]byte, packet.GeneveFixedHeaderLength+len(nakedDisco)) + gh := packet.GeneveHeader{ + Version: 0, + Protocol: packet.GeneveProtocolDisco, + VNI: 1, + Control: true, + } + err := gh.Encode(geneveEncapDisco) + if err != nil { + t.Fatal(err) + } + copy(geneveEncapDisco[packet.GeneveFixedHeaderLength:], nakedDisco) + + nakedWireGuardInitiation := make([]byte, len(geneveEncapDisco)) + binary.LittleEndian.PutUint32(nakedWireGuardInitiation, device.MessageInitiationType) + nakedWireGuardResponse := make([]byte, len(geneveEncapDisco)) + binary.LittleEndian.PutUint32(nakedWireGuardResponse, device.MessageResponseType) + nakedWireGuardCookieReply := make([]byte, len(geneveEncapDisco)) + binary.LittleEndian.PutUint32(nakedWireGuardCookieReply, device.MessageCookieReplyType) + nakedWireGuardTransport := make([]byte, len(geneveEncapDisco)) + binary.LittleEndian.PutUint32(nakedWireGuardTransport, device.MessageTransportType) + + geneveEncapWireGuard := make([]byte, packet.GeneveFixedHeaderLength+len(nakedWireGuardInitiation)) + gh = packet.GeneveHeader{ + Version: 0, + Protocol: packet.GeneveProtocolWireGuard, + VNI: 1, + Control: true, + } + err = gh.Encode(geneveEncapWireGuard) + if err != nil { + t.Fatal(err) + } + copy(geneveEncapWireGuard[packet.GeneveFixedHeaderLength:], nakedWireGuardInitiation) + + geneveEncapDiscoNonZeroGeneveVersion := make([]byte, packet.GeneveFixedHeaderLength+len(nakedDisco)) + gh = packet.GeneveHeader{ + Version: 1, + Protocol: packet.GeneveProtocolDisco, + VNI: 1, + Control: true, + } + err = gh.Encode(geneveEncapDiscoNonZeroGeneveVersion) + if err != nil { + t.Fatal(err) + } + copy(geneveEncapDiscoNonZeroGeneveVersion[packet.GeneveFixedHeaderLength:], nakedDisco) + + geneveEncapDiscoNonZeroGeneveReservedBits := make([]byte, packet.GeneveFixedHeaderLength+len(nakedDisco)) + gh = packet.GeneveHeader{ + Version: 0, + Protocol: packet.GeneveProtocolDisco, + VNI: 1, + Control: true, + } + err = gh.Encode(geneveEncapDiscoNonZeroGeneveReservedBits) + if err != nil { + t.Fatal(err) + } + geneveEncapDiscoNonZeroGeneveReservedBits[1] |= 0x3F + copy(geneveEncapDiscoNonZeroGeneveReservedBits[packet.GeneveFixedHeaderLength:], nakedDisco) + + geneveEncapDiscoNonZeroGeneveVNILSB := make([]byte, packet.GeneveFixedHeaderLength+len(nakedDisco)) + gh = packet.GeneveHeader{ + Version: 0, + Protocol: packet.GeneveProtocolDisco, + VNI: 1, + Control: true, + } + err = gh.Encode(geneveEncapDiscoNonZeroGeneveVNILSB) + if err != nil { + t.Fatal(err) + } + geneveEncapDiscoNonZeroGeneveVNILSB[7] |= 0xFF + copy(geneveEncapDiscoNonZeroGeneveVNILSB[packet.GeneveFixedHeaderLength:], nakedDisco) + + tests := []struct { + name string + msg []byte + wantIsDiscoMsg bool + wantIsGeneveEncap bool + }{ + { + name: "naked disco", + msg: nakedDisco, + wantIsDiscoMsg: true, + wantIsGeneveEncap: false, + }, + { + name: "geneve encap disco", + msg: geneveEncapDisco, + wantIsDiscoMsg: true, + wantIsGeneveEncap: true, + }, + { + name: "geneve encap disco nonzero geneve version", + msg: geneveEncapDiscoNonZeroGeneveVersion, + wantIsDiscoMsg: false, + wantIsGeneveEncap: false, + }, + { + name: "geneve encap disco nonzero geneve reserved bits", + msg: geneveEncapDiscoNonZeroGeneveReservedBits, + wantIsDiscoMsg: false, + wantIsGeneveEncap: false, + }, + { + name: "geneve encap disco nonzero geneve vni lsb", + msg: geneveEncapDiscoNonZeroGeneveVNILSB, + wantIsDiscoMsg: false, + wantIsGeneveEncap: false, + }, + { + name: "geneve encap wireguard", + msg: geneveEncapWireGuard, + wantIsDiscoMsg: false, + wantIsGeneveEncap: false, + }, + { + name: "naked WireGuard Initiation type", + msg: nakedWireGuardInitiation, + wantIsDiscoMsg: false, + wantIsGeneveEncap: false, + }, + { + name: "naked WireGuard Response type", + msg: nakedWireGuardResponse, + wantIsDiscoMsg: false, + wantIsGeneveEncap: false, + }, + { + name: "naked WireGuard Cookie Reply type", + msg: nakedWireGuardCookieReply, + wantIsDiscoMsg: false, + wantIsGeneveEncap: false, + }, + { + name: "naked WireGuard Transport type", + msg: nakedWireGuardTransport, + wantIsDiscoMsg: false, + wantIsGeneveEncap: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotIsDiscoMsg, gotIsGeneveEncap := isDiscoMaybeGeneve(tt.msg) + if gotIsDiscoMsg != tt.wantIsDiscoMsg { + t.Errorf("isDiscoMaybeGeneve() gotIsDiscoMsg = %v, want %v", gotIsDiscoMsg, tt.wantIsDiscoMsg) + } + if gotIsGeneveEncap != tt.wantIsGeneveEncap { + t.Errorf("isDiscoMaybeGeneve() gotIsGeneveEncap = %v, want %v", gotIsGeneveEncap, tt.wantIsGeneveEncap) + } + }) + } +}