fixup! wgengine/magicsock: actually use AF_PACKET socket for raw disco

Signed-off-by: Andrew Dunham <andrew@du.nham.ca>
Change-Id: I2c71d7598b9e30df717329db7dc17cb4ad3f05f6
This commit is contained in:
Andrew Dunham
2024-08-15 13:16:40 -04:00
parent 35b91cb66a
commit 7dde340194
2 changed files with 158 additions and 144 deletions

View File

@@ -6,6 +6,7 @@ package magicsock
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/binary"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@@ -24,7 +25,6 @@ import (
"tailscale.com/disco" "tailscale.com/disco"
"tailscale.com/envknob" "tailscale.com/envknob"
"tailscale.com/net/netns" "tailscale.com/net/netns"
"tailscale.com/net/packet"
"tailscale.com/types/ipproto" "tailscale.com/types/ipproto"
"tailscale.com/types/key" "tailscale.com/types/key"
"tailscale.com/types/logger" "tailscale.com/types/logger"
@@ -51,6 +51,14 @@ var (
// receives the entire IPv4 packet, but not the Ethernet // receives the entire IPv4 packet, but not the Ethernet
// header. // header.
// Double-check that this is a UDP packet; we shouldn't be
// seeing anything else given how we create our AF_PACKET
// socket, but an extra check here is cheap, and matches the
// check that we do in the IPv6 path.
bpf.LoadAbsolute{Off: 9, Size: 1},
bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(ipproto.UDP), SkipTrue: 1, SkipFalse: 0},
bpf.RetConstant{Val: 0x0},
// Disco packets are so small they should never get // Disco packets are so small they should never get
// fragmented, and we don't want to handle reassembly. // fragmented, and we don't want to handle reassembly.
bpf.LoadAbsolute{Off: 6, Size: 2}, bpf.LoadAbsolute{Off: 6, Size: 2},
@@ -235,7 +243,6 @@ func (c *Conn) listenRawDisco(family string) (io.Closer, error) {
var ( var (
ctx = context.Background() ctx = context.Background()
buf [1500]byte buf [1500]byte
pkt packet.Parsed
) )
for { for {
n, _, err := sock.Recvfrom(ctx, buf[:], 0) n, _, err := sock.Recvfrom(ctx, buf[:], 0)
@@ -244,10 +251,11 @@ func (c *Conn) listenRawDisco(family string) (io.Closer, error) {
return nil, fmt.Errorf("reading during raw disco self-test: %w", err) return nil, fmt.Errorf("reading during raw disco self-test: %w", err)
} }
if !decodeDiscoPacket(&pkt, c.discoLogf, buf[:n], family == "ip6") { _ /* src */, _ /* dst */, payload := parseUDPPacket(buf[:n], family == "ip6")
if payload == nil {
continue continue
} }
if payload := pkt.Payload(); !bytes.Equal(payload, testDiscoPacket) { if !bytes.Equal(payload, testDiscoPacket) {
c.discoLogf("listenRawDisco: self-test: received mismatched UDP packet of %d bytes", len(payload)) c.discoLogf("listenRawDisco: self-test: received mismatched UDP packet of %d bytes", len(payload))
continue continue
} }
@@ -260,50 +268,60 @@ func (c *Conn) listenRawDisco(family string) (io.Closer, error) {
return sock, nil return sock, nil
} }
// decodeDiscoPacket decodes a disco packet from buf, using pkt as storage for // parseUDPPacket is a basic parser for UDP packets that returns the source and
// the parsed packet. It returns true if the packet is a valid disco packet, // destination addresses, and the payload. The returned payload is a sub-slice
// and false otherwise. // of the input buffer.
// //
// It will log the reason for the packet being invalid to logf; it is the // It expects to be called with a buffer that contains the entire UDP packet,
// caller's responsibility to control log verbosity. // including the IP header, and one that has been filtered with the BPF
func decodeDiscoPacket(pkt *packet.Parsed, logf logger.Logf, buf []byte, isIPv6 bool) bool { // programs above.
// Do a quick length check before we parse the packet, so we can drop //
// things that we know are too small. // If an error occurs, it will return the zero values for all return values.
var minSize int func parseUDPPacket(buf []byte, isIPv6 bool) (src, dst netip.AddrPort, payload []byte) {
// First, parse the IPv4 or IPv6 header to get to the UDP header. Since
// we assume this was filtered with BPF, we know that there will be no
// IPv6 extension headers.
var (
srcIP, dstIP netip.Addr
udp []byte
)
if isIPv6 { if isIPv6 {
minSize = ipv6.HeaderLen + udpHeaderSize + discoMinHeaderSize // Basic length check to ensure that we don't panic
if len(buf) < ipv6.HeaderLen+udpHeaderSize {
return
}
// Extract the source and destination addresses from the IPv6
// header.
srcIP, _ = netip.AddrFromSlice(buf[8:24])
dstIP, _ = netip.AddrFromSlice(buf[24:40])
// We know that the UDP packet starts immediately after the IPv6
// packet.
udp = buf[ipv6.HeaderLen:]
} else { } else {
minSize = ipv4.HeaderLen + udpHeaderSize + discoMinHeaderSize // This is an IPv4 packet; read the length field from the header.
if len(buf) < ipv4.HeaderLen {
return
} }
if len(buf) < minSize { udpOffset := int((buf[0] & 0x0F) << 2)
logf("decodeDiscoPacket: received packet too small to be a disco packet: %d bytes < %d", len(buf), minSize) if udpOffset+udpHeaderSize > len(buf) {
return false return
} }
// Parse the packet. // Parse the source and destination IPs.
pkt.Decode(buf) srcIP, _ = netip.AddrFromSlice(buf[12:16])
dstIP, _ = netip.AddrFromSlice(buf[16:20])
// Verify that this is a UDP packet. udp = buf[udpOffset:]
if pkt.IPProto != ipproto.UDP {
logf("decodeDiscoPacket: received non-UDP packet: %d", pkt.IPProto)
return false
} }
// Ensure that it's the right version of IP; given how we configure our // Parse the ports
// listening sockets, we shouldn't ever get the wrong one, but it's srcPort := binary.BigEndian.Uint16(udp[0:2])
// best to confirm. dstPort := binary.BigEndian.Uint16(udp[2:4])
var wantVersion uint8
if isIPv6 {
wantVersion = 6
} else {
wantVersion = 4
}
if pkt.IPVersion != wantVersion {
logf("decodeDiscoPacket: received mismatched IP version %d (want %d)", pkt.IPVersion, wantVersion)
return false
}
return true // The payload starts after the UDP header.
payload = udp[8:]
return netip.AddrPortFrom(srcIP, srcPort), netip.AddrPortFrom(dstIP, dstPort), payload
} }
// ethernetProtoIPv4 returns the constant unix.ETH_P_IP, in network byte order. // ethernetProtoIPv4 returns the constant unix.ETH_P_IP, in network byte order.
@@ -358,10 +376,7 @@ func (c *Conn) receiveDisco(pc *socket.Conn, isIPV6 bool) {
dlogf logger.Logf = logger.WithPrefix(c.discoLogf, prefix) dlogf logger.Logf = logger.WithPrefix(c.discoLogf, prefix)
) )
var ( var buf [1500]byte
buf [1500]byte
pkt packet.Parsed
)
for { for {
n, src, err := pc.Recvfrom(ctx, buf[:], 0) n, src, err := pc.Recvfrom(ctx, buf[:], 0)
if debugRawDiscoReads() { if debugRawDiscoReads() {
@@ -375,12 +390,13 @@ func (c *Conn) receiveDisco(pc *socket.Conn, isIPV6 bool) {
return return
} }
if !decodeDiscoPacket(&pkt, dlogf, buf[:n], isIPV6) { srcAddr, dstAddr, payload := parseUDPPacket(buf[:n], family == "ip6")
if payload == nil {
// callee logged // callee logged
continue continue
} }
dstPort := pkt.Dst.Port() dstPort := dstAddr.Port()
if dstPort == 0 { if dstPort == 0 {
logf("[unexpected] received packet for port 0") logf("[unexpected] received packet for port 0")
} }
@@ -417,7 +433,7 @@ func (c *Conn) receiveDisco(pc *socket.Conn, isIPV6 bool) {
metricRecvDiscoPacketIPv4.Add(1) metricRecvDiscoPacketIPv4.Add(1)
} }
c.handleDiscoMessage(pkt.Payload(), pkt.Src, key.NodePublic{}, discoRXPathRawSocket) c.handleDiscoMessage(payload, srcAddr, key.NodePublic{}, discoRXPathRawSocket)
} }
} }

View File

@@ -4,114 +4,112 @@
package magicsock package magicsock
import ( import (
"bytes"
"encoding/binary" "encoding/binary"
"net"
"net/netip" "net/netip"
"testing" "testing"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"golang.org/x/sys/cpu" "golang.org/x/sys/cpu"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"tailscale.com/disco" "tailscale.com/disco"
"tailscale.com/net/packet"
"tailscale.com/types/ipproto"
) )
func TestDecodeDiscoPacket(t *testing.T) { func TestParseUDPPacket(t *testing.T) {
mk4 := func(proto ipproto.Proto, src, dst netip.Addr, data []byte) []byte { src4 := netip.MustParseAddrPort("127.0.0.1:12345")
if !src.Is4() || !dst.Is4() { dst4 := netip.MustParseAddrPort("127.0.0.2:54321")
panic("not an IPv4 address")
src6 := netip.MustParseAddrPort("[::1]:12345")
dst6 := netip.MustParseAddrPort("[::2]:54321")
udp4Packet := []byte{
// IPv4 header
0x45, 0x00, 0x00, 0x26, 0x00, 0x00, 0x00, 0x00,
0x40, 0x11, 0x00, 0x00,
0x7f, 0x00, 0x00, 0x01, // source ip
0x7f, 0x00, 0x00, 0x02, // dest ip
// UDP header
0x30, 0x39, // src port
0xd4, 0x31, // dest port
0x00, 0x12, // length; 8 bytes header + 10 bytes payload = 18 bytes
0x00, 0x00, // checksum; unused
// Payload: disco magic plus 4 bytes
0x54, 0x53, 0xf0, 0x9f, 0x92, 0xac, 0x00, 0x01, 0x02, 0x03,
} }
iph := &ipv4.Header{ udp6Packet := []byte{
Version: ipv4.Version, // IPv6 header
Len: ipv4.HeaderLen, 0x60, 0x00, 0x00, 0x00,
TotalLen: ipv4.HeaderLen + len(data), 0x00, 0x12, // payload length
TTL: 64, 0x11, // next header: UDP
Protocol: int(proto), 0x00, // hop limit; unused
Src: net.IP(src.AsSlice()),
Dst: net.IP(dst.AsSlice()), // Source IP
} 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
hdr, err := iph.Marshal() 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01,
if err != nil { // Dest IP
panic(err) 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
} 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02,
return append(hdr, data...)
} // UDP header
mk6 := func(proto ipproto.Proto, src, dst netip.Addr, data []byte) []byte { 0x30, 0x39, // src port
if !src.Is6() || !dst.Is6() { 0xd4, 0x31, // dest port
panic("not an IPv6 address") 0x00, 0x12, // length; 8 bytes header + 10 bytes payload = 18 bytes
} 0x00, 0x00, // checksum; unused
// The ipv6 package doesn't have a Marshal method, so just do
// the most basic thing and construct the header manually. // Payload: disco magic plus 4 bytes
buf := make([]byte, ipv6.HeaderLen, ipv6.HeaderLen+len(data)) 0x54, 0x53, 0xf0, 0x9f, 0x92, 0xac, 0x00, 0x01, 0x02, 0x03,
buf[0] = 6 << 4 // version
binary.BigEndian.PutUint16(buf[4:6], uint16(len(data)))
buf[6] = byte(proto)
copy(buf[8:24], src.AsSlice())
copy(buf[24:40], dst.AsSlice())
return append(buf, data...)
} }
mkUDP := func(srcPort, dstPort uint16, data []byte) []byte { // Verify that parsing the UDP packet works correctly.
udp := make([]byte, 8, 8+len(data)) t.Run("IPv4", func(t *testing.T) {
binary.BigEndian.PutUint16(udp[0:2], srcPort) src, dst, payload := parseUDPPacket(udp4Packet, false)
binary.BigEndian.PutUint16(udp[2:4], dstPort) if src != src4 {
binary.BigEndian.PutUint16(udp[4:6], uint16(8+len(data))) t.Errorf("src = %v; want %v", src, src4)
return append(udp, data...)
} }
mkUDP4 := func(src, dst netip.AddrPort, data []byte) []byte { if dst != dst4 {
return mk4(ipproto.UDP, src.Addr(), dst.Addr(), mkUDP(src.Port(), dst.Port(), data)) t.Errorf("dst = %v; want %v", dst, dst4)
} }
mkUDP6 := func(src, dst netip.AddrPort, data []byte) []byte { if !bytes.HasPrefix(payload, []byte(disco.Magic)) {
return mk6(ipproto.UDP, src.Addr(), dst.Addr(), mkUDP(src.Port(), dst.Port(), data)) t.Errorf("payload = %x; must start with %x", payload, disco.Magic)
}
ip4 := netip.MustParseAddrPort("127.0.0.10:12345")
ip4_2 := netip.MustParseAddrPort("127.0.0.99:54321")
ip6 := netip.MustParseAddrPort("[::1]:12345")
testCases := []struct {
name string
in []byte
is6 bool
want bool
}{
{
name: "too_short_4",
in: mkUDP4(ip4, ip4_2, append([]byte(disco.Magic), 0x00, 0x00)),
is6: false,
want: false,
},
{
name: "too_short_6",
in: mkUDP6(ip6, ip6, append([]byte(disco.Magic), 0x00, 0x00)),
is6: true,
want: false,
},
{
name: "valid_4",
in: mkUDP4(ip4, ip4_2, testDiscoPacket),
is6: false,
want: true,
},
{
name: "valid_6",
in: mkUDP6(ip6, ip6, testDiscoPacket),
is6: true,
want: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var pkt packet.Parsed
got := decodeDiscoPacket(&pkt, t.Logf, tc.in, tc.is6)
if got != tc.want {
t.Errorf("got %v; want %v", got, tc.want)
} }
}) })
t.Run("IPv6", func(t *testing.T) {
src, dst, payload := parseUDPPacket(udp6Packet, true)
if src != src6 {
t.Errorf("src = %v; want %v", src, src6)
} }
if dst != dst6 {
t.Errorf("dst = %v; want %v", dst, dst6)
}
if !bytes.HasPrefix(payload, []byte(disco.Magic)) {
t.Errorf("payload = %x; must start with %x", payload, disco.Magic)
}
})
t.Run("Truncated", func(t *testing.T) {
truncateBy := func(b []byte, n int) []byte {
if n >= len(b) {
return nil
}
return b[:len(b)-n]
}
src, dst, payload := parseUDPPacket(truncateBy(udp4Packet, 11), false)
if payload != nil {
t.Errorf("payload = %x; want nil", payload)
}
if src.IsValid() || dst.IsValid() {
t.Errorf("src = %v, dst = %v; want invalid", src, dst)
}
src, dst, payload = parseUDPPacket(truncateBy(udp6Packet, 11), true)
if payload != nil {
t.Errorf("payload = %x; want nil", payload)
}
if src.IsValid() || dst.IsValid() {
t.Errorf("src = %v, dst = %v; want invalid", src, dst)
}
})
} }
func TestEthernetProto(t *testing.T) { func TestEthernetProto(t *testing.T) {