mirror of
				https://github.com/tailscale/tailscale.git
				synced 2025-10-26 10:39:19 +00:00 
			
		
		
		
	wgengine/magicsock: actually use AF_PACKET socket for raw disco
Previously, despite what the commit said, we were using a raw IP socket that was *not* an AF_PACKET socket, and thus was subject to the host firewall rules. Switch to using a real AF_PACKET socket to actually get the functionality we want. Updates #13140 Signed-off-by: Andrew Dunham <andrew@du.nham.ca> Change-Id: If657daeeda9ab8d967e75a4f049c66e2bca54b78
This commit is contained in:
		| @@ -5,7 +5,7 @@ package magicsock | ||||
| 
 | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"encoding/binary" | ||||
| 	"context" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| @@ -13,20 +13,28 @@ import ( | ||||
| 	"net/netip" | ||||
| 	"syscall" | ||||
| 	"time" | ||||
| 	"unsafe" | ||||
| 
 | ||||
| 	"github.com/mdlayher/socket" | ||||
| 	"golang.org/x/net/bpf" | ||||
| 	"golang.org/x/net/ipv4" | ||||
| 	"golang.org/x/net/ipv6" | ||||
| 	"golang.org/x/sys/cpu" | ||||
| 	"golang.org/x/sys/unix" | ||||
| 	"tailscale.com/disco" | ||||
| 	"tailscale.com/envknob" | ||||
| 	"tailscale.com/net/netns" | ||||
| 	"tailscale.com/net/packet" | ||||
| 	"tailscale.com/types/ipproto" | ||||
| 	"tailscale.com/types/key" | ||||
| 	"tailscale.com/types/logger" | ||||
| 	"tailscale.com/types/nettype" | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	udpHeaderSize          = 8 | ||||
| 	ipv6FragmentHeaderSize = 8 | ||||
| 	udpHeaderSize = 8 | ||||
| 
 | ||||
| 	// discoMinHeaderSize is the minimum size of the disco header in bytes. | ||||
| 	discoMinHeaderSize = len(disco.Magic) + 32 /* key length */ + disco.NonceLen | ||||
| ) | ||||
| 
 | ||||
| // Enable/disable using raw sockets to receive disco traffic. | ||||
| @@ -38,8 +46,9 @@ var debugRawDiscoReads = envknob.RegisterBool("TS_DEBUG_RAW_DISCO") | ||||
| // These are our BPF filters that we use for testing packets. | ||||
| var ( | ||||
| 	magicsockFilterV4 = []bpf.Instruction{ | ||||
| 		// For raw UDPv4 sockets, BPF receives the entire IP packet to | ||||
| 		// inspect. | ||||
| 		// For raw sockets (with ETH_P_IP set), the BPF program | ||||
| 		// receives the entire IPv4 packet, but not the Ethernet | ||||
| 		// header. | ||||
| 
 | ||||
| 		// Disco packets are so small they should never get | ||||
| 		// fragmented, and we don't want to handle reassembly. | ||||
| @@ -53,6 +62,25 @@ var ( | ||||
| 		// Load IP header length into X register. | ||||
| 		bpf.LoadMemShift{Off: 0}, | ||||
| 
 | ||||
| 		// Verify that we have a packet that's big enough to (possibly) | ||||
| 		// contain a disco packet. | ||||
| 		// | ||||
| 		// The length of an IPv4 disco packet is composed of: | ||||
| 		// - 8 bytes for the UDP header | ||||
| 		// - N bytes for the disco packet header | ||||
| 		// | ||||
| 		// bpf will implicitly return 0 ("skip") if attempting an | ||||
| 		// out-of-bounds load, so we can check the length of the packet | ||||
| 		// loading a byte from that offset here. We subtract 1 byte | ||||
| 		// from the offset to ensure that we accept a packet that's | ||||
| 		// exactly the minimum size. | ||||
| 		// | ||||
| 		// We use LoadIndirect; since we loaded the start of the packet's | ||||
| 		// payload into the X register, above, we don't need to add | ||||
| 		// ipv4.HeaderLen to the offset (and this properly handles IPv4 | ||||
| 		// extensions). | ||||
| 		bpf.LoadIndirect{Off: uint32(udpHeaderSize + discoMinHeaderSize - 1), Size: 1}, | ||||
| 
 | ||||
| 		// Get the first 4 bytes of the UDP packet, compare with our magic number | ||||
| 		bpf.LoadIndirect{Off: udpHeaderSize, Size: 4}, | ||||
| 		bpf.JumpIf{Cond: bpf.JumpEqual, Val: discoMagic1, SkipTrue: 0, SkipFalse: 3}, | ||||
| @@ -82,25 +110,24 @@ var ( | ||||
| 	// and thus we'd rather be conservative here and possibly not receive | ||||
| 	// disco packets rather than slow down the system. | ||||
| 	magicsockFilterV6 = []bpf.Instruction{ | ||||
| 		// For raw UDPv6 sockets, BPF receives _only_ the UDP header onwards, not an entire IP packet. | ||||
| 		// | ||||
| 		//    https://stackoverflow.com/questions/24514333/using-bpf-with-sock-dgram-on-linux-machine | ||||
| 		//    https://blog.cloudflare.com/epbf_sockets_hop_distance/ | ||||
| 		// | ||||
| 		// This is especially confusing because this *isn't* true for | ||||
| 		// IPv4; see the following code from the 'ping' utility that | ||||
| 		// corroborates this: | ||||
| 		// | ||||
| 		//    https://github.com/iputils/iputils/blob/1ab5fa/ping/ping.c#L1667-L1676 | ||||
| 		//    https://github.com/iputils/iputils/blob/1ab5fa/ping/ping6_common.c#L933-L941 | ||||
| 		// Do a bounds check to ensure we have enough space for a disco | ||||
| 		// packet; see the comment in the IPv4 BPF program for more | ||||
| 		// details. | ||||
| 		bpf.LoadAbsolute{Off: uint32(ipv6.HeaderLen + udpHeaderSize + discoMinHeaderSize - 1), Size: 1}, | ||||
| 
 | ||||
| 		// Verify that the 'next header' value of the IPv6 packet is | ||||
| 		// UDP, which is what we're expecting; if it's anything else | ||||
| 		// (including extension headers), we skip the packet. | ||||
| 		bpf.LoadAbsolute{Off: 6, Size: 1}, | ||||
| 		bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(ipproto.UDP), SkipTrue: 0, SkipFalse: 5}, | ||||
| 
 | ||||
| 		// Compare with our magic number. Start by loading and | ||||
| 		// comparing the first 4 bytes of the UDP payload. | ||||
| 		bpf.LoadAbsolute{Off: udpHeaderSize, Size: 4}, | ||||
| 		bpf.LoadAbsolute{Off: ipv6.HeaderLen + udpHeaderSize, Size: 4}, | ||||
| 		bpf.JumpIf{Cond: bpf.JumpEqual, Val: discoMagic1, SkipTrue: 0, SkipFalse: 3}, | ||||
| 
 | ||||
| 		// Compare the next 2 bytes | ||||
| 		bpf.LoadAbsolute{Off: udpHeaderSize + 4, Size: 2}, | ||||
| 		bpf.LoadAbsolute{Off: ipv6.HeaderLen + udpHeaderSize + 4, Size: 2}, | ||||
| 		bpf.JumpIf{Cond: bpf.JumpEqual, Val: discoMagic2, SkipTrue: 0, SkipFalse: 1}, | ||||
| 
 | ||||
| 		// Accept the whole packet | ||||
| @@ -140,21 +167,24 @@ func (c *Conn) listenRawDisco(family string) (io.Closer, error) { | ||||
| 	} | ||||
| 
 | ||||
| 	var ( | ||||
| 		network  string | ||||
| 		udpnet   string | ||||
| 		addr     string | ||||
| 		testAddr string | ||||
| 		proto    int | ||||
| 		testAddr netip.AddrPort | ||||
| 		prog     []bpf.Instruction | ||||
| 	) | ||||
| 	switch family { | ||||
| 	case "ip4": | ||||
| 		network = "ip4:17" | ||||
| 		udpnet = "udp4" | ||||
| 		addr = "0.0.0.0" | ||||
| 		testAddr = "127.0.0.1:1" | ||||
| 		proto = ethernetProtoIPv4() | ||||
| 		testAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), 1) | ||||
| 		prog = magicsockFilterV4 | ||||
| 	case "ip6": | ||||
| 		network = "ip6:17" | ||||
| 		udpnet = "udp6" | ||||
| 		addr = "::" | ||||
| 		testAddr = "[::1]:1" | ||||
| 		proto = ethernetProtoIPv6() | ||||
| 		testAddr = netip.AddrPortFrom(netip.IPv6Loopback(), 1) | ||||
| 		prog = magicsockFilterV6 | ||||
| 	default: | ||||
| 		return nil, fmt.Errorf("unsupported address family %q", family) | ||||
| @@ -165,72 +195,193 @@ func (c *Conn) listenRawDisco(family string) (io.Closer, error) { | ||||
| 		return nil, fmt.Errorf("assembling filter: %w", err) | ||||
| 	} | ||||
| 
 | ||||
| 	pc, err := net.ListenPacket(network, addr) | ||||
| 	sock, err := socket.Socket( | ||||
| 		unix.AF_PACKET, | ||||
| 		unix.SOCK_DGRAM, | ||||
| 		proto, | ||||
| 		"afpacket", | ||||
| 		nil, // no config | ||||
| 	) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("creating packet conn: %w", err) | ||||
| 		return nil, fmt.Errorf("creating AF_PACKET socket: %w", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if err := setBPF(pc, asm); err != nil { | ||||
| 		pc.Close() | ||||
| 	if err := sock.SetBPF(asm); err != nil { | ||||
| 		sock.Close() | ||||
| 		return nil, fmt.Errorf("installing BPF filter: %w", err) | ||||
| 	} | ||||
| 
 | ||||
| 	// If all the above succeeds, we should be ready to receive. Just | ||||
| 	// out of paranoia, check that we do receive a well-formed disco | ||||
| 	// packet. | ||||
| 	tc, err := net.ListenPacket("udp", net.JoinHostPort(addr, "0")) | ||||
| 	tc, err := net.ListenPacket(udpnet, net.JoinHostPort(addr, "0")) | ||||
| 	if err != nil { | ||||
| 		pc.Close() | ||||
| 		sock.Close() | ||||
| 		return nil, fmt.Errorf("creating disco test socket: %w", err) | ||||
| 	} | ||||
| 	defer tc.Close() | ||||
| 	if _, err := tc.(*net.UDPConn).WriteToUDPAddrPort(testDiscoPacket, netip.MustParseAddrPort(testAddr)); err != nil { | ||||
| 		pc.Close() | ||||
| 	if _, err := tc.(*net.UDPConn).WriteToUDPAddrPort(testDiscoPacket, testAddr); err != nil { | ||||
| 		sock.Close() | ||||
| 		return nil, fmt.Errorf("writing disco test packet: %w", err) | ||||
| 	} | ||||
| 	pc.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) | ||||
| 	var buf [1500]byte | ||||
| 
 | ||||
| 	const selfTestTimeout = 100 * time.Millisecond | ||||
| 	if err := sock.SetReadDeadline(time.Now().Add(selfTestTimeout)); err != nil { | ||||
| 		sock.Close() | ||||
| 		return nil, fmt.Errorf("setting socket timeout: %w", err) | ||||
| 	} | ||||
| 
 | ||||
| 	var ( | ||||
| 		ctx = context.Background() | ||||
| 		buf [1500]byte | ||||
| 		pkt packet.Parsed | ||||
| 	) | ||||
| 	for { | ||||
| 		n, _, err := pc.ReadFrom(buf[:]) | ||||
| 		n, _, err := sock.Recvfrom(ctx, buf[:], 0) | ||||
| 		if err != nil { | ||||
| 			pc.Close() | ||||
| 			sock.Close() | ||||
| 			return nil, fmt.Errorf("reading during raw disco self-test: %w", err) | ||||
| 		} | ||||
| 		if n < udpHeaderSize { | ||||
| 
 | ||||
| 		if !decodeDiscoPacket(&pkt, c.discoLogf, buf[:n], family == "ip6") { | ||||
| 			continue | ||||
| 		} | ||||
| 		if !bytes.Equal(buf[udpHeaderSize:n], testDiscoPacket) { | ||||
| 		if payload := pkt.Payload(); !bytes.Equal(payload, testDiscoPacket) { | ||||
| 			c.discoLogf("listenRawDisco: self-test: received mismatched UDP packet of %d bytes", len(payload)) | ||||
| 			continue | ||||
| 		} | ||||
| 		c.logf("[v1] listenRawDisco: self-test passed for %s", family) | ||||
| 		break | ||||
| 	} | ||||
| 	pc.SetReadDeadline(time.Time{}) | ||||
| 	sock.SetReadDeadline(time.Time{}) | ||||
| 
 | ||||
| 	go c.receiveDisco(pc, family == "ip6") | ||||
| 	return pc, nil | ||||
| 	go c.receiveDisco(sock, family == "ip6") | ||||
| 	return sock, nil | ||||
| } | ||||
| 
 | ||||
| func (c *Conn) receiveDisco(pc net.PacketConn, isIPV6 bool) { | ||||
| 	var buf [1500]byte | ||||
| // decodeDiscoPacket decodes a disco packet from buf, using pkt as storage for | ||||
| // the parsed packet. It returns true if the packet is a valid disco packet, | ||||
| // and false otherwise. | ||||
| // | ||||
| // It will log the reason for the packet being invalid to logf; it is the | ||||
| // caller's responsibility to control log verbosity. | ||||
| func decodeDiscoPacket(pkt *packet.Parsed, logf logger.Logf, buf []byte, isIPv6 bool) bool { | ||||
| 	// Do a quick length check before we parse the packet, so we can drop | ||||
| 	// things that we know are too small. | ||||
| 	var minSize int | ||||
| 	if isIPv6 { | ||||
| 		minSize = ipv6.HeaderLen + udpHeaderSize + discoMinHeaderSize | ||||
| 	} else { | ||||
| 		minSize = ipv4.HeaderLen + udpHeaderSize + discoMinHeaderSize | ||||
| 	} | ||||
| 	if len(buf) < minSize { | ||||
| 		logf("decodeDiscoPacket: received packet too small to be a disco packet: %d bytes < %d", len(buf), minSize) | ||||
| 		return false | ||||
| 	} | ||||
| 
 | ||||
| 	// Parse the packet. | ||||
| 	pkt.Decode(buf) | ||||
| 
 | ||||
| 	// Verify that this is a UDP packet. | ||||
| 	if pkt.IPProto != ipproto.UDP { | ||||
| 		logf("decodeDiscoPacket: received non-UDP packet: %d", pkt.IPProto) | ||||
| 		return false | ||||
| 	} | ||||
| 
 | ||||
| 	// Ensure that it's the right version of IP; given how we configure our | ||||
| 	// listening sockets, we shouldn't ever get the wrong one, but it's | ||||
| 	// best to confirm. | ||||
| 	var wantVersion uint8 | ||||
| 	if isIPv6 { | ||||
| 		wantVersion = 6 | ||||
| 	} else { | ||||
| 		wantVersion = 4 | ||||
| 	} | ||||
| 	if pkt.IPVersion != wantVersion { | ||||
| 		logf("decodeDiscoPacket: received mismatched IP version %d (want %d)", pkt.IPVersion, wantVersion) | ||||
| 		return false | ||||
| 	} | ||||
| 
 | ||||
| 	return true | ||||
| } | ||||
| 
 | ||||
| // ethernetProtoIPv4 returns the constant unix.ETH_P_IP, in network byte order. | ||||
| // packet(7) sockets require that the 'protocol' argument be in network byte | ||||
| // order; see: | ||||
| // | ||||
| //	https://man7.org/linux/man-pages/man7/packet.7.html | ||||
| // | ||||
| // Instead of using htons at runtime, we can just hardcode the value here... | ||||
| // but we also have a test that verifies that this is correct. | ||||
| func ethernetProtoIPv4() int { | ||||
| 	if cpu.IsBigEndian { | ||||
| 		return 0x0800 | ||||
| 	} else { | ||||
| 		return 0x0008 | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // ethernetProtoIPv6 returns the constant unix.ETH_P_IPV6, and is otherwise the | ||||
| // same as ethernetProtoIPv4. | ||||
| func ethernetProtoIPv6() int { | ||||
| 	if cpu.IsBigEndian { | ||||
| 		return 0x86dd | ||||
| 	} else { | ||||
| 		return 0xdd86 | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (c *Conn) discoLogf(format string, args ...any) { | ||||
| 	// Enable debug logging if we're debugging raw disco reads or if the | ||||
| 	// magicsock component logs are on. | ||||
| 	if debugRawDiscoReads() { | ||||
| 		c.logf(format, args...) | ||||
| 	} else { | ||||
| 		c.dlogf(format, args...) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (c *Conn) receiveDisco(pc *socket.Conn, isIPV6 bool) { | ||||
| 	ctx := context.Background() | ||||
| 
 | ||||
| 	// Set up our loggers | ||||
| 	var family string | ||||
| 	if isIPV6 { | ||||
| 		family = "ip6" | ||||
| 	} else { | ||||
| 		family = "ip4" | ||||
| 	} | ||||
| 	var ( | ||||
| 		prefix string      = "disco raw " + family + ": " | ||||
| 		logf   logger.Logf = logger.WithPrefix(c.logf, prefix) | ||||
| 		dlogf  logger.Logf = logger.WithPrefix(c.discoLogf, prefix) | ||||
| 	) | ||||
| 
 | ||||
| 	var ( | ||||
| 		buf [1500]byte | ||||
| 		pkt packet.Parsed | ||||
| 	) | ||||
| 	for { | ||||
| 		n, src, err := pc.ReadFrom(buf[:]) | ||||
| 		n, src, err := pc.Recvfrom(ctx, buf[:], 0) | ||||
| 		if debugRawDiscoReads() { | ||||
| 			c.logf("raw disco read from %v = (%v, %v)", src, n, err) | ||||
| 			logf("read from %v = (%v, %v)", src, n, err) | ||||
| 		} | ||||
| 		if errors.Is(err, net.ErrClosed) { | ||||
| 		if err != nil && (errors.Is(err, net.ErrClosed) || err.Error() == "use of closed file") { | ||||
| 			// EOF; no need to print an error | ||||
| 			return | ||||
| 		} else if err != nil { | ||||
| 			c.logf("disco raw reader failed: %v", err) | ||||
| 			logf("reader failed: %v", err) | ||||
| 			return | ||||
| 		} | ||||
| 		if n < udpHeaderSize { | ||||
| 			// Too small to be a valid UDP datagram, drop. | ||||
| 
 | ||||
| 		if !decodeDiscoPacket(&pkt, dlogf, buf[:n], isIPV6) { | ||||
| 			// callee logged | ||||
| 			continue | ||||
| 		} | ||||
| 
 | ||||
| 		dstPort := binary.BigEndian.Uint16(buf[2:4]) | ||||
| 		dstPort := pkt.Dst.Port() | ||||
| 		if dstPort == 0 { | ||||
| 			c.logf("[unexpected] disco raw: received packet for port 0") | ||||
| 			logf("[unexpected] received packet for port 0") | ||||
| 		} | ||||
| 
 | ||||
| 		var acceptPort uint16 | ||||
| @@ -242,61 +393,33 @@ func (c *Conn) receiveDisco(pc net.PacketConn, isIPV6 bool) { | ||||
| 		if acceptPort == 0 { | ||||
| 			// This should only typically happen if the receiving address family | ||||
| 			// was recently disabled. | ||||
| 			c.dlogf("[v1] disco raw: dropping packet for port %d as acceptPort=0", dstPort) | ||||
| 			dlogf("[v1] dropping packet for port %d as acceptPort=0", dstPort) | ||||
| 			continue | ||||
| 		} | ||||
| 
 | ||||
| 		// If the packet isn't destined for our local port, then we | ||||
| 		// should drop it since it might be for another Tailscale | ||||
| 		// process on the same machine, or NATed to a different machine | ||||
| 		// if this is a router, etc. | ||||
| 		// | ||||
| 		// We get the local port to compare against inside the receive | ||||
| 		// loop; we can't cache this beforehand because it can change | ||||
| 		// if/when we rebind. | ||||
| 		if dstPort != acceptPort { | ||||
| 			c.dlogf("[v1] disco raw: dropping packet for port %d", dstPort) | ||||
| 			dlogf("[v1] dropping packet for port %d that isn't our local port", dstPort) | ||||
| 			continue | ||||
| 		} | ||||
| 
 | ||||
| 		srcIP, ok := netip.AddrFromSlice(src.(*net.IPAddr).IP) | ||||
| 		if !ok { | ||||
| 			c.logf("[unexpected] PacketConn.ReadFrom returned not-an-IP %v in from", src) | ||||
| 			continue | ||||
| 		} | ||||
| 		srcPort := binary.BigEndian.Uint16(buf[:2]) | ||||
| 
 | ||||
| 		if srcIP.Is4() { | ||||
| 			metricRecvDiscoPacketIPv4.Add(1) | ||||
| 		} else { | ||||
| 		if isIPV6 { | ||||
| 			metricRecvDiscoPacketIPv6.Add(1) | ||||
| 		} else { | ||||
| 			metricRecvDiscoPacketIPv4.Add(1) | ||||
| 		} | ||||
| 
 | ||||
| 		c.handleDiscoMessage(buf[udpHeaderSize:n], netip.AddrPortFrom(srcIP, srcPort), key.NodePublic{}, discoRXPathRawSocket) | ||||
| 		c.handleDiscoMessage(pkt.Payload(), pkt.Src, key.NodePublic{}, discoRXPathRawSocket) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // setBPF installs filter as the BPF filter on conn. | ||||
| // Ideally we would just use SetBPF as implemented in x/net/ipv4, | ||||
| // but x/net/ipv6 doesn't implement it. And once you've written | ||||
| // this code once, it turns out to be address family agnostic, so | ||||
| // we might as well use it on both and get to use a net.PacketConn | ||||
| // directly for both families instead of being stuck with | ||||
| // different types. | ||||
| func setBPF(conn net.PacketConn, filter []bpf.RawInstruction) error { | ||||
| 	sc, err := conn.(*net.IPConn).SyscallConn() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	prog := &unix.SockFprog{ | ||||
| 		Len:    uint16(len(filter)), | ||||
| 		Filter: (*unix.SockFilter)(unsafe.Pointer(&filter[0])), | ||||
| 	} | ||||
| 	var setErr error | ||||
| 	err = sc.Control(func(fd uintptr) { | ||||
| 		setErr = unix.SetsockoptSockFprog(int(fd), unix.SOL_SOCKET, unix.SO_ATTACH_FILTER, prog) | ||||
| 	}) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if setErr != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // trySetSocketBuffer attempts to set SO_SNDBUFFORCE and SO_RECVBUFFORCE which | ||||
| // can overcome the limit of net.core.{r,w}mem_max, but require CAP_NET_ADMIN. | ||||
| // It falls back to the portable implementation if that fails, which may be | ||||
|   | ||||
							
								
								
									
										150
									
								
								wgengine/magicsock/magicsock_linux_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										150
									
								
								wgengine/magicsock/magicsock_linux_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,150 @@ | ||||
| // Copyright (c) Tailscale Inc & AUTHORS | ||||
| // SPDX-License-Identifier: BSD-3-Clause | ||||
| 
 | ||||
| package magicsock | ||||
| 
 | ||||
| import ( | ||||
| 	"encoding/binary" | ||||
| 	"net" | ||||
| 	"net/netip" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"golang.org/x/net/ipv4" | ||||
| 	"golang.org/x/net/ipv6" | ||||
| 	"golang.org/x/sys/cpu" | ||||
| 	"golang.org/x/sys/unix" | ||||
| 	"tailscale.com/disco" | ||||
| 	"tailscale.com/net/packet" | ||||
| 	"tailscale.com/types/ipproto" | ||||
| ) | ||||
| 
 | ||||
| func TestDecodeDiscoPacket(t *testing.T) { | ||||
| 	mk4 := func(proto ipproto.Proto, src, dst netip.Addr, data []byte) []byte { | ||||
| 		if !src.Is4() || !dst.Is4() { | ||||
| 			panic("not an IPv4 address") | ||||
| 		} | ||||
| 		iph := &ipv4.Header{ | ||||
| 			Version:  ipv4.Version, | ||||
| 			Len:      ipv4.HeaderLen, | ||||
| 			TotalLen: ipv4.HeaderLen + len(data), | ||||
| 			TTL:      64, | ||||
| 			Protocol: int(proto), | ||||
| 			Src:      net.IP(src.AsSlice()), | ||||
| 			Dst:      net.IP(dst.AsSlice()), | ||||
| 		} | ||||
| 		hdr, err := iph.Marshal() | ||||
| 		if err != nil { | ||||
| 			panic(err) | ||||
| 		} | ||||
| 		return append(hdr, data...) | ||||
| 	} | ||||
| 	mk6 := func(proto ipproto.Proto, src, dst netip.Addr, data []byte) []byte { | ||||
| 		if !src.Is6() || !dst.Is6() { | ||||
| 			panic("not an IPv6 address") | ||||
| 		} | ||||
| 		// The ipv6 package doesn't have a Marshal method, so just do | ||||
| 		// the most basic thing and construct the header manually. | ||||
| 		buf := make([]byte, ipv6.HeaderLen, ipv6.HeaderLen+len(data)) | ||||
| 		buf[0] = 6 << 4 // version | ||||
| 		binary.BigEndian.PutUint16(buf[4:6], uint16(len(data))) | ||||
| 		buf[6] = byte(proto) | ||||
| 		copy(buf[8:24], src.AsSlice()) | ||||
| 		copy(buf[24:40], dst.AsSlice()) | ||||
| 		return append(buf, data...) | ||||
| 	} | ||||
| 
 | ||||
| 	mkUDP := func(srcPort, dstPort uint16, data []byte) []byte { | ||||
| 		udp := make([]byte, 8, 8+len(data)) | ||||
| 		binary.BigEndian.PutUint16(udp[0:2], srcPort) | ||||
| 		binary.BigEndian.PutUint16(udp[2:4], dstPort) | ||||
| 		binary.BigEndian.PutUint16(udp[4:6], uint16(8+len(data))) | ||||
| 		return append(udp, data...) | ||||
| 	} | ||||
| 	mkUDP4 := func(src, dst netip.AddrPort, data []byte) []byte { | ||||
| 		return mk4(ipproto.UDP, src.Addr(), dst.Addr(), mkUDP(src.Port(), dst.Port(), data)) | ||||
| 	} | ||||
| 	mkUDP6 := func(src, dst netip.AddrPort, data []byte) []byte { | ||||
| 		return mk6(ipproto.UDP, src.Addr(), dst.Addr(), mkUDP(src.Port(), dst.Port(), data)) | ||||
| 	} | ||||
| 
 | ||||
| 	ip4 := netip.MustParseAddrPort("127.0.0.10:12345") | ||||
| 	ip4_2 := netip.MustParseAddrPort("127.0.0.99:54321") | ||||
| 	ip6 := netip.MustParseAddrPort("[::1]:12345") | ||||
| 
 | ||||
| 	testCases := []struct { | ||||
| 		name string | ||||
| 		in   []byte | ||||
| 		is6  bool | ||||
| 		want bool | ||||
| 	}{ | ||||
| 		{ | ||||
| 			name: "too_short_4", | ||||
| 			in:   mkUDP4(ip4, ip4_2, append([]byte(disco.Magic), 0x00, 0x00)), | ||||
| 			is6:  false, | ||||
| 			want: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "too_short_6", | ||||
| 			in:   mkUDP6(ip6, ip6, append([]byte(disco.Magic), 0x00, 0x00)), | ||||
| 			is6:  true, | ||||
| 			want: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "valid_4", | ||||
| 			in:   mkUDP4(ip4, ip4_2, testDiscoPacket), | ||||
| 			is6:  false, | ||||
| 			want: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "valid_6", | ||||
| 			in:   mkUDP6(ip6, ip6, testDiscoPacket), | ||||
| 			is6:  true, | ||||
| 			want: true, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	for _, tc := range testCases { | ||||
| 		t.Run(tc.name, func(t *testing.T) { | ||||
| 			var pkt packet.Parsed | ||||
| 			got := decodeDiscoPacket(&pkt, t.Logf, tc.in, tc.is6) | ||||
| 			if got != tc.want { | ||||
| 				t.Errorf("got %v; want %v", got, tc.want) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestEthernetProto(t *testing.T) { | ||||
| 	htons := func(x uint16) int { | ||||
| 		// Network byte order is big-endian; write the value as | ||||
| 		// big-endian to a byte slice and read it back in the native | ||||
| 		// endian-ness. This is a no-op on a big-endian platform and a | ||||
| 		// byte swap on a little-endian platform. | ||||
| 		var b [2]byte | ||||
| 		binary.BigEndian.PutUint16(b[:], x) | ||||
| 		return int(binary.NativeEndian.Uint16(b[:])) | ||||
| 	} | ||||
| 
 | ||||
| 	if v4 := ethernetProtoIPv4(); v4 != htons(unix.ETH_P_IP) { | ||||
| 		t.Errorf("ethernetProtoIPv4 = 0x%04x; want 0x%04x", v4, htons(unix.ETH_P_IP)) | ||||
| 	} | ||||
| 	if v6 := ethernetProtoIPv6(); v6 != htons(unix.ETH_P_IPV6) { | ||||
| 		t.Errorf("ethernetProtoIPv6 = 0x%04x; want 0x%04x", v6, htons(unix.ETH_P_IPV6)) | ||||
| 	} | ||||
| 
 | ||||
| 	// As a way to verify that the htons function is working correctly, | ||||
| 	// assert that the ETH_P_IP value returned from our function matches | ||||
| 	// the value defined in the unix package based on whether the host is | ||||
| 	// big-endian (network byte order) or little-endian. | ||||
| 	if cpu.IsBigEndian { | ||||
| 		if v4 := ethernetProtoIPv4(); v4 != unix.ETH_P_IP { | ||||
| 			t.Errorf("ethernetProtoIPv4 = 0x%04x; want 0x%04x", v4, unix.ETH_P_IP) | ||||
| 		} | ||||
| 	} else { | ||||
| 		if v4 := ethernetProtoIPv4(); v4 == unix.ETH_P_IP { | ||||
| 			t.Errorf("ethernetProtoIPv4 = 0x%04x; want 0x%04x", v4, htons(unix.ETH_P_IP)) | ||||
| 		} else { | ||||
| 			t.Logf("ethernetProtoIPv4 = 0x%04x, correctly different from 0x%04x", v4, unix.ETH_P_IP) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
		Reference in New Issue
	
	Block a user
	 Andrew Dunham
					Andrew Dunham