fixup! wgengine/magicsock: actually use AF_PACKET socket for raw disco

Signed-off-by: Andrew Dunham <andrew@du.nham.ca>
Change-Id: I2c71d7598b9e30df717329db7dc17cb4ad3f05f6
This commit is contained in:
Andrew Dunham 2024-08-15 13:16:40 -04:00
parent 35b91cb66a
commit 7dde340194
2 changed files with 158 additions and 144 deletions

View File

@ -6,6 +6,7 @@ package magicsock
import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
"io"
@ -24,7 +25,6 @@ import (
"tailscale.com/disco"
"tailscale.com/envknob"
"tailscale.com/net/netns"
"tailscale.com/net/packet"
"tailscale.com/types/ipproto"
"tailscale.com/types/key"
"tailscale.com/types/logger"
@ -51,6 +51,14 @@ var (
// receives the entire IPv4 packet, but not the Ethernet
// header.
// Double-check that this is a UDP packet; we shouldn't be
// seeing anything else given how we create our AF_PACKET
// socket, but an extra check here is cheap, and matches the
// check that we do in the IPv6 path.
bpf.LoadAbsolute{Off: 9, Size: 1},
bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(ipproto.UDP), SkipTrue: 1, SkipFalse: 0},
bpf.RetConstant{Val: 0x0},
// Disco packets are so small they should never get
// fragmented, and we don't want to handle reassembly.
bpf.LoadAbsolute{Off: 6, Size: 2},
@ -235,7 +243,6 @@ func (c *Conn) listenRawDisco(family string) (io.Closer, error) {
var (
ctx = context.Background()
buf [1500]byte
pkt packet.Parsed
)
for {
n, _, err := sock.Recvfrom(ctx, buf[:], 0)
@ -244,10 +251,11 @@ func (c *Conn) listenRawDisco(family string) (io.Closer, error) {
return nil, fmt.Errorf("reading during raw disco self-test: %w", err)
}
if !decodeDiscoPacket(&pkt, c.discoLogf, buf[:n], family == "ip6") {
_ /* src */, _ /* dst */, payload := parseUDPPacket(buf[:n], family == "ip6")
if payload == nil {
continue
}
if payload := pkt.Payload(); !bytes.Equal(payload, testDiscoPacket) {
if !bytes.Equal(payload, testDiscoPacket) {
c.discoLogf("listenRawDisco: self-test: received mismatched UDP packet of %d bytes", len(payload))
continue
}
@ -260,50 +268,60 @@ func (c *Conn) listenRawDisco(family string) (io.Closer, error) {
return sock, nil
}
// decodeDiscoPacket decodes a disco packet from buf, using pkt as storage for
// the parsed packet. It returns true if the packet is a valid disco packet,
// and false otherwise.
// parseUDPPacket is a basic parser for UDP packets that returns the source and
// destination addresses, and the payload. The returned payload is a sub-slice
// of the input buffer.
//
// It will log the reason for the packet being invalid to logf; it is the
// caller's responsibility to control log verbosity.
func decodeDiscoPacket(pkt *packet.Parsed, logf logger.Logf, buf []byte, isIPv6 bool) bool {
// Do a quick length check before we parse the packet, so we can drop
// things that we know are too small.
var minSize int
// It expects to be called with a buffer that contains the entire UDP packet,
// including the IP header, and one that has been filtered with the BPF
// programs above.
//
// If an error occurs, it will return the zero values for all return values.
func parseUDPPacket(buf []byte, isIPv6 bool) (src, dst netip.AddrPort, payload []byte) {
// First, parse the IPv4 or IPv6 header to get to the UDP header. Since
// we assume this was filtered with BPF, we know that there will be no
// IPv6 extension headers.
var (
srcIP, dstIP netip.Addr
udp []byte
)
if isIPv6 {
minSize = ipv6.HeaderLen + udpHeaderSize + discoMinHeaderSize
// Basic length check to ensure that we don't panic
if len(buf) < ipv6.HeaderLen+udpHeaderSize {
return
}
// Extract the source and destination addresses from the IPv6
// header.
srcIP, _ = netip.AddrFromSlice(buf[8:24])
dstIP, _ = netip.AddrFromSlice(buf[24:40])
// We know that the UDP packet starts immediately after the IPv6
// packet.
udp = buf[ipv6.HeaderLen:]
} else {
minSize = ipv4.HeaderLen + udpHeaderSize + discoMinHeaderSize
}
if len(buf) < minSize {
logf("decodeDiscoPacket: received packet too small to be a disco packet: %d bytes < %d", len(buf), minSize)
return false
// This is an IPv4 packet; read the length field from the header.
if len(buf) < ipv4.HeaderLen {
return
}
udpOffset := int((buf[0] & 0x0F) << 2)
if udpOffset+udpHeaderSize > len(buf) {
return
}
// Parse the source and destination IPs.
srcIP, _ = netip.AddrFromSlice(buf[12:16])
dstIP, _ = netip.AddrFromSlice(buf[16:20])
udp = buf[udpOffset:]
}
// Parse the packet.
pkt.Decode(buf)
// Parse the ports
srcPort := binary.BigEndian.Uint16(udp[0:2])
dstPort := binary.BigEndian.Uint16(udp[2:4])
// Verify that this is a UDP packet.
if pkt.IPProto != ipproto.UDP {
logf("decodeDiscoPacket: received non-UDP packet: %d", pkt.IPProto)
return false
}
// Ensure that it's the right version of IP; given how we configure our
// listening sockets, we shouldn't ever get the wrong one, but it's
// best to confirm.
var wantVersion uint8
if isIPv6 {
wantVersion = 6
} else {
wantVersion = 4
}
if pkt.IPVersion != wantVersion {
logf("decodeDiscoPacket: received mismatched IP version %d (want %d)", pkt.IPVersion, wantVersion)
return false
}
return true
// The payload starts after the UDP header.
payload = udp[8:]
return netip.AddrPortFrom(srcIP, srcPort), netip.AddrPortFrom(dstIP, dstPort), payload
}
// ethernetProtoIPv4 returns the constant unix.ETH_P_IP, in network byte order.
@ -358,10 +376,7 @@ func (c *Conn) receiveDisco(pc *socket.Conn, isIPV6 bool) {
dlogf logger.Logf = logger.WithPrefix(c.discoLogf, prefix)
)
var (
buf [1500]byte
pkt packet.Parsed
)
var buf [1500]byte
for {
n, src, err := pc.Recvfrom(ctx, buf[:], 0)
if debugRawDiscoReads() {
@ -375,12 +390,13 @@ func (c *Conn) receiveDisco(pc *socket.Conn, isIPV6 bool) {
return
}
if !decodeDiscoPacket(&pkt, dlogf, buf[:n], isIPV6) {
srcAddr, dstAddr, payload := parseUDPPacket(buf[:n], family == "ip6")
if payload == nil {
// callee logged
continue
}
dstPort := pkt.Dst.Port()
dstPort := dstAddr.Port()
if dstPort == 0 {
logf("[unexpected] received packet for port 0")
}
@ -417,7 +433,7 @@ func (c *Conn) receiveDisco(pc *socket.Conn, isIPV6 bool) {
metricRecvDiscoPacketIPv4.Add(1)
}
c.handleDiscoMessage(pkt.Payload(), pkt.Src, key.NodePublic{}, discoRXPathRawSocket)
c.handleDiscoMessage(payload, srcAddr, key.NodePublic{}, discoRXPathRawSocket)
}
}

View File

@ -4,114 +4,112 @@
package magicsock
import (
"bytes"
"encoding/binary"
"net"
"net/netip"
"testing"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"golang.org/x/sys/cpu"
"golang.org/x/sys/unix"
"tailscale.com/disco"
"tailscale.com/net/packet"
"tailscale.com/types/ipproto"
)
func TestDecodeDiscoPacket(t *testing.T) {
mk4 := func(proto ipproto.Proto, src, dst netip.Addr, data []byte) []byte {
if !src.Is4() || !dst.Is4() {
panic("not an IPv4 address")
}
iph := &ipv4.Header{
Version: ipv4.Version,
Len: ipv4.HeaderLen,
TotalLen: ipv4.HeaderLen + len(data),
TTL: 64,
Protocol: int(proto),
Src: net.IP(src.AsSlice()),
Dst: net.IP(dst.AsSlice()),
}
hdr, err := iph.Marshal()
if err != nil {
panic(err)
}
return append(hdr, data...)
func TestParseUDPPacket(t *testing.T) {
src4 := netip.MustParseAddrPort("127.0.0.1:12345")
dst4 := netip.MustParseAddrPort("127.0.0.2:54321")
src6 := netip.MustParseAddrPort("[::1]:12345")
dst6 := netip.MustParseAddrPort("[::2]:54321")
udp4Packet := []byte{
// IPv4 header
0x45, 0x00, 0x00, 0x26, 0x00, 0x00, 0x00, 0x00,
0x40, 0x11, 0x00, 0x00,
0x7f, 0x00, 0x00, 0x01, // source ip
0x7f, 0x00, 0x00, 0x02, // dest ip
// UDP header
0x30, 0x39, // src port
0xd4, 0x31, // dest port
0x00, 0x12, // length; 8 bytes header + 10 bytes payload = 18 bytes
0x00, 0x00, // checksum; unused
// Payload: disco magic plus 4 bytes
0x54, 0x53, 0xf0, 0x9f, 0x92, 0xac, 0x00, 0x01, 0x02, 0x03,
}
mk6 := func(proto ipproto.Proto, src, dst netip.Addr, data []byte) []byte {
if !src.Is6() || !dst.Is6() {
panic("not an IPv6 address")
}
// The ipv6 package doesn't have a Marshal method, so just do
// the most basic thing and construct the header manually.
buf := make([]byte, ipv6.HeaderLen, ipv6.HeaderLen+len(data))
buf[0] = 6 << 4 // version
binary.BigEndian.PutUint16(buf[4:6], uint16(len(data)))
buf[6] = byte(proto)
copy(buf[8:24], src.AsSlice())
copy(buf[24:40], dst.AsSlice())
return append(buf, data...)
udp6Packet := []byte{
// IPv6 header
0x60, 0x00, 0x00, 0x00,
0x00, 0x12, // payload length
0x11, // next header: UDP
0x00, // hop limit; unused
// Source IP
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01,
// Dest IP
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02,
// UDP header
0x30, 0x39, // src port
0xd4, 0x31, // dest port
0x00, 0x12, // length; 8 bytes header + 10 bytes payload = 18 bytes
0x00, 0x00, // checksum; unused
// Payload: disco magic plus 4 bytes
0x54, 0x53, 0xf0, 0x9f, 0x92, 0xac, 0x00, 0x01, 0x02, 0x03,
}
mkUDP := func(srcPort, dstPort uint16, data []byte) []byte {
udp := make([]byte, 8, 8+len(data))
binary.BigEndian.PutUint16(udp[0:2], srcPort)
binary.BigEndian.PutUint16(udp[2:4], dstPort)
binary.BigEndian.PutUint16(udp[4:6], uint16(8+len(data)))
return append(udp, data...)
}
mkUDP4 := func(src, dst netip.AddrPort, data []byte) []byte {
return mk4(ipproto.UDP, src.Addr(), dst.Addr(), mkUDP(src.Port(), dst.Port(), data))
}
mkUDP6 := func(src, dst netip.AddrPort, data []byte) []byte {
return mk6(ipproto.UDP, src.Addr(), dst.Addr(), mkUDP(src.Port(), dst.Port(), data))
}
ip4 := netip.MustParseAddrPort("127.0.0.10:12345")
ip4_2 := netip.MustParseAddrPort("127.0.0.99:54321")
ip6 := netip.MustParseAddrPort("[::1]:12345")
testCases := []struct {
name string
in []byte
is6 bool
want bool
}{
{
name: "too_short_4",
in: mkUDP4(ip4, ip4_2, append([]byte(disco.Magic), 0x00, 0x00)),
is6: false,
want: false,
},
{
name: "too_short_6",
in: mkUDP6(ip6, ip6, append([]byte(disco.Magic), 0x00, 0x00)),
is6: true,
want: false,
},
{
name: "valid_4",
in: mkUDP4(ip4, ip4_2, testDiscoPacket),
is6: false,
want: true,
},
{
name: "valid_6",
in: mkUDP6(ip6, ip6, testDiscoPacket),
is6: true,
want: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var pkt packet.Parsed
got := decodeDiscoPacket(&pkt, t.Logf, tc.in, tc.is6)
if got != tc.want {
t.Errorf("got %v; want %v", got, tc.want)
// Verify that parsing the UDP packet works correctly.
t.Run("IPv4", func(t *testing.T) {
src, dst, payload := parseUDPPacket(udp4Packet, false)
if src != src4 {
t.Errorf("src = %v; want %v", src, src4)
}
if dst != dst4 {
t.Errorf("dst = %v; want %v", dst, dst4)
}
if !bytes.HasPrefix(payload, []byte(disco.Magic)) {
t.Errorf("payload = %x; must start with %x", payload, disco.Magic)
}
})
t.Run("IPv6", func(t *testing.T) {
src, dst, payload := parseUDPPacket(udp6Packet, true)
if src != src6 {
t.Errorf("src = %v; want %v", src, src6)
}
if dst != dst6 {
t.Errorf("dst = %v; want %v", dst, dst6)
}
if !bytes.HasPrefix(payload, []byte(disco.Magic)) {
t.Errorf("payload = %x; must start with %x", payload, disco.Magic)
}
})
t.Run("Truncated", func(t *testing.T) {
truncateBy := func(b []byte, n int) []byte {
if n >= len(b) {
return nil
}
})
}
return b[:len(b)-n]
}
src, dst, payload := parseUDPPacket(truncateBy(udp4Packet, 11), false)
if payload != nil {
t.Errorf("payload = %x; want nil", payload)
}
if src.IsValid() || dst.IsValid() {
t.Errorf("src = %v, dst = %v; want invalid", src, dst)
}
src, dst, payload = parseUDPPacket(truncateBy(udp6Packet, 11), true)
if payload != nil {
t.Errorf("payload = %x; want nil", payload)
}
if src.IsValid() || dst.IsValid() {
t.Errorf("src = %v, dst = %v; want invalid", src, dst)
}
})
}
func TestEthernetProto(t *testing.T) {