diff --git a/wgengine/magicsock/magicsock_linux.go b/wgengine/magicsock/magicsock_linux.go index ebb0988d1..d8a8ad74f 100644 --- a/wgengine/magicsock/magicsock_linux.go +++ b/wgengine/magicsock/magicsock_linux.go @@ -6,6 +6,7 @@ package magicsock import ( "bytes" "context" + "encoding/binary" "errors" "fmt" "io" @@ -24,7 +25,6 @@ import ( "tailscale.com/disco" "tailscale.com/envknob" "tailscale.com/net/netns" - "tailscale.com/net/packet" "tailscale.com/types/ipproto" "tailscale.com/types/key" "tailscale.com/types/logger" @@ -51,6 +51,14 @@ var ( // receives the entire IPv4 packet, but not the Ethernet // header. + // Double-check that this is a UDP packet; we shouldn't be + // seeing anything else given how we create our AF_PACKET + // socket, but an extra check here is cheap, and matches the + // check that we do in the IPv6 path. + bpf.LoadAbsolute{Off: 9, Size: 1}, + bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(ipproto.UDP), SkipTrue: 1, SkipFalse: 0}, + bpf.RetConstant{Val: 0x0}, + // Disco packets are so small they should never get // fragmented, and we don't want to handle reassembly. bpf.LoadAbsolute{Off: 6, Size: 2}, @@ -235,7 +243,6 @@ func (c *Conn) listenRawDisco(family string) (io.Closer, error) { var ( ctx = context.Background() buf [1500]byte - pkt packet.Parsed ) for { n, _, err := sock.Recvfrom(ctx, buf[:], 0) @@ -244,10 +251,11 @@ func (c *Conn) listenRawDisco(family string) (io.Closer, error) { return nil, fmt.Errorf("reading during raw disco self-test: %w", err) } - if !decodeDiscoPacket(&pkt, c.discoLogf, buf[:n], family == "ip6") { + _ /* src */, _ /* dst */, payload := parseUDPPacket(buf[:n], family == "ip6") + if payload == nil { continue } - if payload := pkt.Payload(); !bytes.Equal(payload, testDiscoPacket) { + if !bytes.Equal(payload, testDiscoPacket) { c.discoLogf("listenRawDisco: self-test: received mismatched UDP packet of %d bytes", len(payload)) continue } @@ -260,50 +268,60 @@ func (c *Conn) listenRawDisco(family string) (io.Closer, error) { return sock, nil } -// decodeDiscoPacket decodes a disco packet from buf, using pkt as storage for -// the parsed packet. It returns true if the packet is a valid disco packet, -// and false otherwise. +// parseUDPPacket is a basic parser for UDP packets that returns the source and +// destination addresses, and the payload. The returned payload is a sub-slice +// of the input buffer. // -// It will log the reason for the packet being invalid to logf; it is the -// caller's responsibility to control log verbosity. -func decodeDiscoPacket(pkt *packet.Parsed, logf logger.Logf, buf []byte, isIPv6 bool) bool { - // Do a quick length check before we parse the packet, so we can drop - // things that we know are too small. - var minSize int +// It expects to be called with a buffer that contains the entire UDP packet, +// including the IP header, and one that has been filtered with the BPF +// programs above. +// +// If an error occurs, it will return the zero values for all return values. +func parseUDPPacket(buf []byte, isIPv6 bool) (src, dst netip.AddrPort, payload []byte) { + // First, parse the IPv4 or IPv6 header to get to the UDP header. Since + // we assume this was filtered with BPF, we know that there will be no + // IPv6 extension headers. + var ( + srcIP, dstIP netip.Addr + udp []byte + ) if isIPv6 { - minSize = ipv6.HeaderLen + udpHeaderSize + discoMinHeaderSize + // Basic length check to ensure that we don't panic + if len(buf) < ipv6.HeaderLen+udpHeaderSize { + return + } + + // Extract the source and destination addresses from the IPv6 + // header. + srcIP, _ = netip.AddrFromSlice(buf[8:24]) + dstIP, _ = netip.AddrFromSlice(buf[24:40]) + + // We know that the UDP packet starts immediately after the IPv6 + // packet. + udp = buf[ipv6.HeaderLen:] } else { - minSize = ipv4.HeaderLen + udpHeaderSize + discoMinHeaderSize - } - if len(buf) < minSize { - logf("decodeDiscoPacket: received packet too small to be a disco packet: %d bytes < %d", len(buf), minSize) - return false + // This is an IPv4 packet; read the length field from the header. + if len(buf) < ipv4.HeaderLen { + return + } + udpOffset := int((buf[0] & 0x0F) << 2) + if udpOffset+udpHeaderSize > len(buf) { + return + } + + // Parse the source and destination IPs. + srcIP, _ = netip.AddrFromSlice(buf[12:16]) + dstIP, _ = netip.AddrFromSlice(buf[16:20]) + udp = buf[udpOffset:] } - // Parse the packet. - pkt.Decode(buf) + // Parse the ports + srcPort := binary.BigEndian.Uint16(udp[0:2]) + dstPort := binary.BigEndian.Uint16(udp[2:4]) - // Verify that this is a UDP packet. - if pkt.IPProto != ipproto.UDP { - logf("decodeDiscoPacket: received non-UDP packet: %d", pkt.IPProto) - return false - } - - // Ensure that it's the right version of IP; given how we configure our - // listening sockets, we shouldn't ever get the wrong one, but it's - // best to confirm. - var wantVersion uint8 - if isIPv6 { - wantVersion = 6 - } else { - wantVersion = 4 - } - if pkt.IPVersion != wantVersion { - logf("decodeDiscoPacket: received mismatched IP version %d (want %d)", pkt.IPVersion, wantVersion) - return false - } - - return true + // The payload starts after the UDP header. + payload = udp[8:] + return netip.AddrPortFrom(srcIP, srcPort), netip.AddrPortFrom(dstIP, dstPort), payload } // ethernetProtoIPv4 returns the constant unix.ETH_P_IP, in network byte order. @@ -358,10 +376,7 @@ func (c *Conn) receiveDisco(pc *socket.Conn, isIPV6 bool) { dlogf logger.Logf = logger.WithPrefix(c.discoLogf, prefix) ) - var ( - buf [1500]byte - pkt packet.Parsed - ) + var buf [1500]byte for { n, src, err := pc.Recvfrom(ctx, buf[:], 0) if debugRawDiscoReads() { @@ -375,12 +390,13 @@ func (c *Conn) receiveDisco(pc *socket.Conn, isIPV6 bool) { return } - if !decodeDiscoPacket(&pkt, dlogf, buf[:n], isIPV6) { + srcAddr, dstAddr, payload := parseUDPPacket(buf[:n], family == "ip6") + if payload == nil { // callee logged continue } - dstPort := pkt.Dst.Port() + dstPort := dstAddr.Port() if dstPort == 0 { logf("[unexpected] received packet for port 0") } @@ -417,7 +433,7 @@ func (c *Conn) receiveDisco(pc *socket.Conn, isIPV6 bool) { metricRecvDiscoPacketIPv4.Add(1) } - c.handleDiscoMessage(pkt.Payload(), pkt.Src, key.NodePublic{}, discoRXPathRawSocket) + c.handleDiscoMessage(payload, srcAddr, key.NodePublic{}, discoRXPathRawSocket) } } diff --git a/wgengine/magicsock/magicsock_linux_test.go b/wgengine/magicsock/magicsock_linux_test.go index e9e7d73d8..6b86b04f2 100644 --- a/wgengine/magicsock/magicsock_linux_test.go +++ b/wgengine/magicsock/magicsock_linux_test.go @@ -4,114 +4,112 @@ package magicsock import ( + "bytes" "encoding/binary" - "net" "net/netip" "testing" - "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" "golang.org/x/sys/cpu" "golang.org/x/sys/unix" "tailscale.com/disco" - "tailscale.com/net/packet" - "tailscale.com/types/ipproto" ) -func TestDecodeDiscoPacket(t *testing.T) { - mk4 := func(proto ipproto.Proto, src, dst netip.Addr, data []byte) []byte { - if !src.Is4() || !dst.Is4() { - panic("not an IPv4 address") - } - iph := &ipv4.Header{ - Version: ipv4.Version, - Len: ipv4.HeaderLen, - TotalLen: ipv4.HeaderLen + len(data), - TTL: 64, - Protocol: int(proto), - Src: net.IP(src.AsSlice()), - Dst: net.IP(dst.AsSlice()), - } - hdr, err := iph.Marshal() - if err != nil { - panic(err) - } - return append(hdr, data...) +func TestParseUDPPacket(t *testing.T) { + src4 := netip.MustParseAddrPort("127.0.0.1:12345") + dst4 := netip.MustParseAddrPort("127.0.0.2:54321") + + src6 := netip.MustParseAddrPort("[::1]:12345") + dst6 := netip.MustParseAddrPort("[::2]:54321") + + udp4Packet := []byte{ + // IPv4 header + 0x45, 0x00, 0x00, 0x26, 0x00, 0x00, 0x00, 0x00, + 0x40, 0x11, 0x00, 0x00, + 0x7f, 0x00, 0x00, 0x01, // source ip + 0x7f, 0x00, 0x00, 0x02, // dest ip + + // UDP header + 0x30, 0x39, // src port + 0xd4, 0x31, // dest port + 0x00, 0x12, // length; 8 bytes header + 10 bytes payload = 18 bytes + 0x00, 0x00, // checksum; unused + + // Payload: disco magic plus 4 bytes + 0x54, 0x53, 0xf0, 0x9f, 0x92, 0xac, 0x00, 0x01, 0x02, 0x03, } - mk6 := func(proto ipproto.Proto, src, dst netip.Addr, data []byte) []byte { - if !src.Is6() || !dst.Is6() { - panic("not an IPv6 address") - } - // The ipv6 package doesn't have a Marshal method, so just do - // the most basic thing and construct the header manually. - buf := make([]byte, ipv6.HeaderLen, ipv6.HeaderLen+len(data)) - buf[0] = 6 << 4 // version - binary.BigEndian.PutUint16(buf[4:6], uint16(len(data))) - buf[6] = byte(proto) - copy(buf[8:24], src.AsSlice()) - copy(buf[24:40], dst.AsSlice()) - return append(buf, data...) + udp6Packet := []byte{ + // IPv6 header + 0x60, 0x00, 0x00, 0x00, + 0x00, 0x12, // payload length + 0x11, // next header: UDP + 0x00, // hop limit; unused + + // Source IP + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + // Dest IP + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, + + // UDP header + 0x30, 0x39, // src port + 0xd4, 0x31, // dest port + 0x00, 0x12, // length; 8 bytes header + 10 bytes payload = 18 bytes + 0x00, 0x00, // checksum; unused + + // Payload: disco magic plus 4 bytes + 0x54, 0x53, 0xf0, 0x9f, 0x92, 0xac, 0x00, 0x01, 0x02, 0x03, } - mkUDP := func(srcPort, dstPort uint16, data []byte) []byte { - udp := make([]byte, 8, 8+len(data)) - binary.BigEndian.PutUint16(udp[0:2], srcPort) - binary.BigEndian.PutUint16(udp[2:4], dstPort) - binary.BigEndian.PutUint16(udp[4:6], uint16(8+len(data))) - return append(udp, data...) - } - mkUDP4 := func(src, dst netip.AddrPort, data []byte) []byte { - return mk4(ipproto.UDP, src.Addr(), dst.Addr(), mkUDP(src.Port(), dst.Port(), data)) - } - mkUDP6 := func(src, dst netip.AddrPort, data []byte) []byte { - return mk6(ipproto.UDP, src.Addr(), dst.Addr(), mkUDP(src.Port(), dst.Port(), data)) - } - - ip4 := netip.MustParseAddrPort("127.0.0.10:12345") - ip4_2 := netip.MustParseAddrPort("127.0.0.99:54321") - ip6 := netip.MustParseAddrPort("[::1]:12345") - - testCases := []struct { - name string - in []byte - is6 bool - want bool - }{ - { - name: "too_short_4", - in: mkUDP4(ip4, ip4_2, append([]byte(disco.Magic), 0x00, 0x00)), - is6: false, - want: false, - }, - { - name: "too_short_6", - in: mkUDP6(ip6, ip6, append([]byte(disco.Magic), 0x00, 0x00)), - is6: true, - want: false, - }, - { - name: "valid_4", - in: mkUDP4(ip4, ip4_2, testDiscoPacket), - is6: false, - want: true, - }, - { - name: "valid_6", - in: mkUDP6(ip6, ip6, testDiscoPacket), - is6: true, - want: true, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - var pkt packet.Parsed - got := decodeDiscoPacket(&pkt, t.Logf, tc.in, tc.is6) - if got != tc.want { - t.Errorf("got %v; want %v", got, tc.want) + // Verify that parsing the UDP packet works correctly. + t.Run("IPv4", func(t *testing.T) { + src, dst, payload := parseUDPPacket(udp4Packet, false) + if src != src4 { + t.Errorf("src = %v; want %v", src, src4) + } + if dst != dst4 { + t.Errorf("dst = %v; want %v", dst, dst4) + } + if !bytes.HasPrefix(payload, []byte(disco.Magic)) { + t.Errorf("payload = %x; must start with %x", payload, disco.Magic) + } + }) + t.Run("IPv6", func(t *testing.T) { + src, dst, payload := parseUDPPacket(udp6Packet, true) + if src != src6 { + t.Errorf("src = %v; want %v", src, src6) + } + if dst != dst6 { + t.Errorf("dst = %v; want %v", dst, dst6) + } + if !bytes.HasPrefix(payload, []byte(disco.Magic)) { + t.Errorf("payload = %x; must start with %x", payload, disco.Magic) + } + }) + t.Run("Truncated", func(t *testing.T) { + truncateBy := func(b []byte, n int) []byte { + if n >= len(b) { + return nil } - }) - } + return b[:len(b)-n] + } + + src, dst, payload := parseUDPPacket(truncateBy(udp4Packet, 11), false) + if payload != nil { + t.Errorf("payload = %x; want nil", payload) + } + if src.IsValid() || dst.IsValid() { + t.Errorf("src = %v, dst = %v; want invalid", src, dst) + } + + src, dst, payload = parseUDPPacket(truncateBy(udp6Packet, 11), true) + if payload != nil { + t.Errorf("payload = %x; want nil", payload) + } + if src.IsValid() || dst.IsValid() { + t.Errorf("src = %v, dst = %v; want invalid", src, dst) + } + }) } func TestEthernetProto(t *testing.T) {