mirror of
https://github.com/tailscale/tailscale.git
synced 2025-10-24 17:48:57 +00:00
fixup! wgengine/magicsock: actually use AF_PACKET socket for raw disco
Signed-off-by: Andrew Dunham <andrew@du.nham.ca> Change-Id: I2c71d7598b9e30df717329db7dc17cb4ad3f05f6
This commit is contained in:
@@ -6,6 +6,7 @@ package magicsock
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -24,7 +25,6 @@ import (
|
||||
"tailscale.com/disco"
|
||||
"tailscale.com/envknob"
|
||||
"tailscale.com/net/netns"
|
||||
"tailscale.com/net/packet"
|
||||
"tailscale.com/types/ipproto"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/logger"
|
||||
@@ -51,6 +51,14 @@ var (
|
||||
// receives the entire IPv4 packet, but not the Ethernet
|
||||
// header.
|
||||
|
||||
// Double-check that this is a UDP packet; we shouldn't be
|
||||
// seeing anything else given how we create our AF_PACKET
|
||||
// socket, but an extra check here is cheap, and matches the
|
||||
// check that we do in the IPv6 path.
|
||||
bpf.LoadAbsolute{Off: 9, Size: 1},
|
||||
bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(ipproto.UDP), SkipTrue: 1, SkipFalse: 0},
|
||||
bpf.RetConstant{Val: 0x0},
|
||||
|
||||
// Disco packets are so small they should never get
|
||||
// fragmented, and we don't want to handle reassembly.
|
||||
bpf.LoadAbsolute{Off: 6, Size: 2},
|
||||
@@ -235,7 +243,6 @@ func (c *Conn) listenRawDisco(family string) (io.Closer, error) {
|
||||
var (
|
||||
ctx = context.Background()
|
||||
buf [1500]byte
|
||||
pkt packet.Parsed
|
||||
)
|
||||
for {
|
||||
n, _, err := sock.Recvfrom(ctx, buf[:], 0)
|
||||
@@ -244,10 +251,11 @@ func (c *Conn) listenRawDisco(family string) (io.Closer, error) {
|
||||
return nil, fmt.Errorf("reading during raw disco self-test: %w", err)
|
||||
}
|
||||
|
||||
if !decodeDiscoPacket(&pkt, c.discoLogf, buf[:n], family == "ip6") {
|
||||
_ /* src */, _ /* dst */, payload := parseUDPPacket(buf[:n], family == "ip6")
|
||||
if payload == nil {
|
||||
continue
|
||||
}
|
||||
if payload := pkt.Payload(); !bytes.Equal(payload, testDiscoPacket) {
|
||||
if !bytes.Equal(payload, testDiscoPacket) {
|
||||
c.discoLogf("listenRawDisco: self-test: received mismatched UDP packet of %d bytes", len(payload))
|
||||
continue
|
||||
}
|
||||
@@ -260,50 +268,60 @@ func (c *Conn) listenRawDisco(family string) (io.Closer, error) {
|
||||
return sock, nil
|
||||
}
|
||||
|
||||
// decodeDiscoPacket decodes a disco packet from buf, using pkt as storage for
|
||||
// the parsed packet. It returns true if the packet is a valid disco packet,
|
||||
// and false otherwise.
|
||||
// parseUDPPacket is a basic parser for UDP packets that returns the source and
|
||||
// destination addresses, and the payload. The returned payload is a sub-slice
|
||||
// of the input buffer.
|
||||
//
|
||||
// It will log the reason for the packet being invalid to logf; it is the
|
||||
// caller's responsibility to control log verbosity.
|
||||
func decodeDiscoPacket(pkt *packet.Parsed, logf logger.Logf, buf []byte, isIPv6 bool) bool {
|
||||
// Do a quick length check before we parse the packet, so we can drop
|
||||
// things that we know are too small.
|
||||
var minSize int
|
||||
// It expects to be called with a buffer that contains the entire UDP packet,
|
||||
// including the IP header, and one that has been filtered with the BPF
|
||||
// programs above.
|
||||
//
|
||||
// If an error occurs, it will return the zero values for all return values.
|
||||
func parseUDPPacket(buf []byte, isIPv6 bool) (src, dst netip.AddrPort, payload []byte) {
|
||||
// First, parse the IPv4 or IPv6 header to get to the UDP header. Since
|
||||
// we assume this was filtered with BPF, we know that there will be no
|
||||
// IPv6 extension headers.
|
||||
var (
|
||||
srcIP, dstIP netip.Addr
|
||||
udp []byte
|
||||
)
|
||||
if isIPv6 {
|
||||
minSize = ipv6.HeaderLen + udpHeaderSize + discoMinHeaderSize
|
||||
// Basic length check to ensure that we don't panic
|
||||
if len(buf) < ipv6.HeaderLen+udpHeaderSize {
|
||||
return
|
||||
}
|
||||
|
||||
// Extract the source and destination addresses from the IPv6
|
||||
// header.
|
||||
srcIP, _ = netip.AddrFromSlice(buf[8:24])
|
||||
dstIP, _ = netip.AddrFromSlice(buf[24:40])
|
||||
|
||||
// We know that the UDP packet starts immediately after the IPv6
|
||||
// packet.
|
||||
udp = buf[ipv6.HeaderLen:]
|
||||
} else {
|
||||
minSize = ipv4.HeaderLen + udpHeaderSize + discoMinHeaderSize
|
||||
}
|
||||
if len(buf) < minSize {
|
||||
logf("decodeDiscoPacket: received packet too small to be a disco packet: %d bytes < %d", len(buf), minSize)
|
||||
return false
|
||||
// This is an IPv4 packet; read the length field from the header.
|
||||
if len(buf) < ipv4.HeaderLen {
|
||||
return
|
||||
}
|
||||
udpOffset := int((buf[0] & 0x0F) << 2)
|
||||
if udpOffset+udpHeaderSize > len(buf) {
|
||||
return
|
||||
}
|
||||
|
||||
// Parse the source and destination IPs.
|
||||
srcIP, _ = netip.AddrFromSlice(buf[12:16])
|
||||
dstIP, _ = netip.AddrFromSlice(buf[16:20])
|
||||
udp = buf[udpOffset:]
|
||||
}
|
||||
|
||||
// Parse the packet.
|
||||
pkt.Decode(buf)
|
||||
// Parse the ports
|
||||
srcPort := binary.BigEndian.Uint16(udp[0:2])
|
||||
dstPort := binary.BigEndian.Uint16(udp[2:4])
|
||||
|
||||
// Verify that this is a UDP packet.
|
||||
if pkt.IPProto != ipproto.UDP {
|
||||
logf("decodeDiscoPacket: received non-UDP packet: %d", pkt.IPProto)
|
||||
return false
|
||||
}
|
||||
|
||||
// Ensure that it's the right version of IP; given how we configure our
|
||||
// listening sockets, we shouldn't ever get the wrong one, but it's
|
||||
// best to confirm.
|
||||
var wantVersion uint8
|
||||
if isIPv6 {
|
||||
wantVersion = 6
|
||||
} else {
|
||||
wantVersion = 4
|
||||
}
|
||||
if pkt.IPVersion != wantVersion {
|
||||
logf("decodeDiscoPacket: received mismatched IP version %d (want %d)", pkt.IPVersion, wantVersion)
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
// The payload starts after the UDP header.
|
||||
payload = udp[8:]
|
||||
return netip.AddrPortFrom(srcIP, srcPort), netip.AddrPortFrom(dstIP, dstPort), payload
|
||||
}
|
||||
|
||||
// ethernetProtoIPv4 returns the constant unix.ETH_P_IP, in network byte order.
|
||||
@@ -358,10 +376,7 @@ func (c *Conn) receiveDisco(pc *socket.Conn, isIPV6 bool) {
|
||||
dlogf logger.Logf = logger.WithPrefix(c.discoLogf, prefix)
|
||||
)
|
||||
|
||||
var (
|
||||
buf [1500]byte
|
||||
pkt packet.Parsed
|
||||
)
|
||||
var buf [1500]byte
|
||||
for {
|
||||
n, src, err := pc.Recvfrom(ctx, buf[:], 0)
|
||||
if debugRawDiscoReads() {
|
||||
@@ -375,12 +390,13 @@ func (c *Conn) receiveDisco(pc *socket.Conn, isIPV6 bool) {
|
||||
return
|
||||
}
|
||||
|
||||
if !decodeDiscoPacket(&pkt, dlogf, buf[:n], isIPV6) {
|
||||
srcAddr, dstAddr, payload := parseUDPPacket(buf[:n], family == "ip6")
|
||||
if payload == nil {
|
||||
// callee logged
|
||||
continue
|
||||
}
|
||||
|
||||
dstPort := pkt.Dst.Port()
|
||||
dstPort := dstAddr.Port()
|
||||
if dstPort == 0 {
|
||||
logf("[unexpected] received packet for port 0")
|
||||
}
|
||||
@@ -417,7 +433,7 @@ func (c *Conn) receiveDisco(pc *socket.Conn, isIPV6 bool) {
|
||||
metricRecvDiscoPacketIPv4.Add(1)
|
||||
}
|
||||
|
||||
c.handleDiscoMessage(pkt.Payload(), pkt.Src, key.NodePublic{}, discoRXPathRawSocket)
|
||||
c.handleDiscoMessage(payload, srcAddr, key.NodePublic{}, discoRXPathRawSocket)
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -4,114 +4,112 @@
|
||||
package magicsock
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
"golang.org/x/sys/cpu"
|
||||
"golang.org/x/sys/unix"
|
||||
"tailscale.com/disco"
|
||||
"tailscale.com/net/packet"
|
||||
"tailscale.com/types/ipproto"
|
||||
)
|
||||
|
||||
func TestDecodeDiscoPacket(t *testing.T) {
|
||||
mk4 := func(proto ipproto.Proto, src, dst netip.Addr, data []byte) []byte {
|
||||
if !src.Is4() || !dst.Is4() {
|
||||
panic("not an IPv4 address")
|
||||
}
|
||||
iph := &ipv4.Header{
|
||||
Version: ipv4.Version,
|
||||
Len: ipv4.HeaderLen,
|
||||
TotalLen: ipv4.HeaderLen + len(data),
|
||||
TTL: 64,
|
||||
Protocol: int(proto),
|
||||
Src: net.IP(src.AsSlice()),
|
||||
Dst: net.IP(dst.AsSlice()),
|
||||
}
|
||||
hdr, err := iph.Marshal()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return append(hdr, data...)
|
||||
func TestParseUDPPacket(t *testing.T) {
|
||||
src4 := netip.MustParseAddrPort("127.0.0.1:12345")
|
||||
dst4 := netip.MustParseAddrPort("127.0.0.2:54321")
|
||||
|
||||
src6 := netip.MustParseAddrPort("[::1]:12345")
|
||||
dst6 := netip.MustParseAddrPort("[::2]:54321")
|
||||
|
||||
udp4Packet := []byte{
|
||||
// IPv4 header
|
||||
0x45, 0x00, 0x00, 0x26, 0x00, 0x00, 0x00, 0x00,
|
||||
0x40, 0x11, 0x00, 0x00,
|
||||
0x7f, 0x00, 0x00, 0x01, // source ip
|
||||
0x7f, 0x00, 0x00, 0x02, // dest ip
|
||||
|
||||
// UDP header
|
||||
0x30, 0x39, // src port
|
||||
0xd4, 0x31, // dest port
|
||||
0x00, 0x12, // length; 8 bytes header + 10 bytes payload = 18 bytes
|
||||
0x00, 0x00, // checksum; unused
|
||||
|
||||
// Payload: disco magic plus 4 bytes
|
||||
0x54, 0x53, 0xf0, 0x9f, 0x92, 0xac, 0x00, 0x01, 0x02, 0x03,
|
||||
}
|
||||
mk6 := func(proto ipproto.Proto, src, dst netip.Addr, data []byte) []byte {
|
||||
if !src.Is6() || !dst.Is6() {
|
||||
panic("not an IPv6 address")
|
||||
}
|
||||
// The ipv6 package doesn't have a Marshal method, so just do
|
||||
// the most basic thing and construct the header manually.
|
||||
buf := make([]byte, ipv6.HeaderLen, ipv6.HeaderLen+len(data))
|
||||
buf[0] = 6 << 4 // version
|
||||
binary.BigEndian.PutUint16(buf[4:6], uint16(len(data)))
|
||||
buf[6] = byte(proto)
|
||||
copy(buf[8:24], src.AsSlice())
|
||||
copy(buf[24:40], dst.AsSlice())
|
||||
return append(buf, data...)
|
||||
udp6Packet := []byte{
|
||||
// IPv6 header
|
||||
0x60, 0x00, 0x00, 0x00,
|
||||
0x00, 0x12, // payload length
|
||||
0x11, // next header: UDP
|
||||
0x00, // hop limit; unused
|
||||
|
||||
// Source IP
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01,
|
||||
// Dest IP
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02,
|
||||
|
||||
// UDP header
|
||||
0x30, 0x39, // src port
|
||||
0xd4, 0x31, // dest port
|
||||
0x00, 0x12, // length; 8 bytes header + 10 bytes payload = 18 bytes
|
||||
0x00, 0x00, // checksum; unused
|
||||
|
||||
// Payload: disco magic plus 4 bytes
|
||||
0x54, 0x53, 0xf0, 0x9f, 0x92, 0xac, 0x00, 0x01, 0x02, 0x03,
|
||||
}
|
||||
|
||||
mkUDP := func(srcPort, dstPort uint16, data []byte) []byte {
|
||||
udp := make([]byte, 8, 8+len(data))
|
||||
binary.BigEndian.PutUint16(udp[0:2], srcPort)
|
||||
binary.BigEndian.PutUint16(udp[2:4], dstPort)
|
||||
binary.BigEndian.PutUint16(udp[4:6], uint16(8+len(data)))
|
||||
return append(udp, data...)
|
||||
}
|
||||
mkUDP4 := func(src, dst netip.AddrPort, data []byte) []byte {
|
||||
return mk4(ipproto.UDP, src.Addr(), dst.Addr(), mkUDP(src.Port(), dst.Port(), data))
|
||||
}
|
||||
mkUDP6 := func(src, dst netip.AddrPort, data []byte) []byte {
|
||||
return mk6(ipproto.UDP, src.Addr(), dst.Addr(), mkUDP(src.Port(), dst.Port(), data))
|
||||
}
|
||||
|
||||
ip4 := netip.MustParseAddrPort("127.0.0.10:12345")
|
||||
ip4_2 := netip.MustParseAddrPort("127.0.0.99:54321")
|
||||
ip6 := netip.MustParseAddrPort("[::1]:12345")
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
in []byte
|
||||
is6 bool
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "too_short_4",
|
||||
in: mkUDP4(ip4, ip4_2, append([]byte(disco.Magic), 0x00, 0x00)),
|
||||
is6: false,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "too_short_6",
|
||||
in: mkUDP6(ip6, ip6, append([]byte(disco.Magic), 0x00, 0x00)),
|
||||
is6: true,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "valid_4",
|
||||
in: mkUDP4(ip4, ip4_2, testDiscoPacket),
|
||||
is6: false,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "valid_6",
|
||||
in: mkUDP6(ip6, ip6, testDiscoPacket),
|
||||
is6: true,
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var pkt packet.Parsed
|
||||
got := decodeDiscoPacket(&pkt, t.Logf, tc.in, tc.is6)
|
||||
if got != tc.want {
|
||||
t.Errorf("got %v; want %v", got, tc.want)
|
||||
// Verify that parsing the UDP packet works correctly.
|
||||
t.Run("IPv4", func(t *testing.T) {
|
||||
src, dst, payload := parseUDPPacket(udp4Packet, false)
|
||||
if src != src4 {
|
||||
t.Errorf("src = %v; want %v", src, src4)
|
||||
}
|
||||
if dst != dst4 {
|
||||
t.Errorf("dst = %v; want %v", dst, dst4)
|
||||
}
|
||||
if !bytes.HasPrefix(payload, []byte(disco.Magic)) {
|
||||
t.Errorf("payload = %x; must start with %x", payload, disco.Magic)
|
||||
}
|
||||
})
|
||||
t.Run("IPv6", func(t *testing.T) {
|
||||
src, dst, payload := parseUDPPacket(udp6Packet, true)
|
||||
if src != src6 {
|
||||
t.Errorf("src = %v; want %v", src, src6)
|
||||
}
|
||||
if dst != dst6 {
|
||||
t.Errorf("dst = %v; want %v", dst, dst6)
|
||||
}
|
||||
if !bytes.HasPrefix(payload, []byte(disco.Magic)) {
|
||||
t.Errorf("payload = %x; must start with %x", payload, disco.Magic)
|
||||
}
|
||||
})
|
||||
t.Run("Truncated", func(t *testing.T) {
|
||||
truncateBy := func(b []byte, n int) []byte {
|
||||
if n >= len(b) {
|
||||
return nil
|
||||
}
|
||||
})
|
||||
}
|
||||
return b[:len(b)-n]
|
||||
}
|
||||
|
||||
src, dst, payload := parseUDPPacket(truncateBy(udp4Packet, 11), false)
|
||||
if payload != nil {
|
||||
t.Errorf("payload = %x; want nil", payload)
|
||||
}
|
||||
if src.IsValid() || dst.IsValid() {
|
||||
t.Errorf("src = %v, dst = %v; want invalid", src, dst)
|
||||
}
|
||||
|
||||
src, dst, payload = parseUDPPacket(truncateBy(udp6Packet, 11), true)
|
||||
if payload != nil {
|
||||
t.Errorf("payload = %x; want nil", payload)
|
||||
}
|
||||
if src.IsValid() || dst.IsValid() {
|
||||
t.Errorf("src = %v, dst = %v; want invalid", src, dst)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestEthernetProto(t *testing.T) {
|
||||
|
Reference in New Issue
Block a user