tailscale/wgengine/packet/packet.go

366 lines
8.8 KiB
Go

// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package packet
import (
"encoding/binary"
"fmt"
"log"
"net"
"strings"
"tailscale.com/types/strbuilder"
)
// IPProto classifies the IP subprotocol of a packet decoded by QDecode.
type IPProto int

// IPProto values. Junk is the zero value and marks packets that could not
// be (or were deliberately not) decoded.
const (
	Junk IPProto = iota
	Fragment
	ICMP
	UDP
	TCP
)

// minFrag is the smallest number of subprotocol bytes an initial fragment
// must carry (max IPv4 header + basic TCP header). Shorter initial
// fragments are refused to prevent the overlapping fragment attacks
// described in RFC 1858.
const minFrag = 60 + 20

// String returns a short name for p; unknown values print as "Junk".
func (p IPProto) String() string {
	names := [...]string{
		Fragment: "Frag",
		ICMP:     "ICMP",
		UDP:      "UDP",
		TCP:      "TCP",
	}
	if p <= Junk || int(p) >= len(names) {
		return "Junk"
	}
	return names[p]
}
// IP is an IPv4 address held in a uint32, with the first octet in the
// most significant byte.
type IP uint32

// NewIP converts a standard library IP address into an IP.
// It panics if b is not an IPv4 address (b.To4() returns nil);
// IPv4-mapped IPv6 addresses are accepted.
func NewIP(b net.IP) IP {
	if v4 := b.To4(); v4 != nil {
		return IP(binary.BigEndian.Uint32(v4))
	}
	panic(fmt.Sprintf("To4(%v) failed", b))
}

// String returns ip in dotted-decimal form, e.g. "192.168.0.1".
func (ip IP) String() string {
	var quad [4]byte
	binary.BigEndian.PutUint32(quad[:], uint32(ip))
	return fmt.Sprintf("%d.%d.%d.%d", quad[0], quad[1], quad[2], quad[3])
}
// ICMP message type numbers (RFC 792) that this package inspects or
// generates.
const (
	ICMPEchoReply    = 0  // echo reply
	ICMPUnreachable  = 3  // destination unreachable
	ICMPEchoRequest  = 8  // echo request
	ICMPTimeExceeded = 11 // time exceeded
)

// TCP flag bits as found in byte 13 of the TCP header.
const (
	TCPSyn    = 1 << 1
	TCPAck    = 1 << 4
	TCPSynAck = TCPSyn | TCPAck
)
// QDecode is a lightweight decoded view of an IPv4 packet, filled in by
// Decode. Only the subprotocol, addresses, ports and TCP flags are
// extracted; the packet itself is referenced, not copied.
type QDecode struct {
	b        []byte  // Packet buffer that this decodes; Decode sets it only for ICMP, TCP and UDP packets and leaves it nil otherwise
	subofs   int     // byte offset of IP subprotocol (IPv4 header length from the IHL field)
	IPProto  IPProto // IP subprotocol (UDP, TCP, etc)
	SrcIP    IP      // IP source address
	DstIP    IP      // IP destination address
	SrcPort  uint16  // TCP/UDP source port (Decode sets 0 for ICMP)
	DstPort  uint16  // TCP/UDP destination port (Decode sets 0 for ICMP)
	TCPFlags uint8   // TCP flags (SYN, ACK, etc): low 6 bits of TCP header byte 13
}
// String returns a one-line summary of q in the form
// "PROTO{src.ip:port > dst.ip:port}", or "Junk{}" for Junk packets.
func (q *QDecode) String() string {
	proto := q.IPProto
	if proto == Junk {
		return "Junk{}"
	}
	out := strbuilder.Get()
	out.WriteString(proto.String())
	out.WriteByte('{')
	writeIPPort(out, q.SrcIP, q.SrcPort)
	out.WriteString(" > ")
	writeIPPort(out, q.DstIP, q.DstPort)
	out.WriteByte('}')
	return out.String()
}
// writeIPPort writes ip and port to sb as "a.b.c.d:port", every
// component in decimal, most significant octet first.
func writeIPPort(sb *strbuilder.Builder, ip IP, port uint16) {
	for shift := uint(24); ; shift -= 8 {
		sb.WriteUint(uint64(byte(ip >> shift)))
		if shift == 0 {
			break
		}
		sb.WriteByte('.')
	}
	sb.WriteByte(':')
	sb.WriteUint(uint64(port))
}
// ipChecksum computes the Internet checksum of b as described in
// https://tools.ietf.org/html/rfc1071: the one's complement of the one's
// complement sum of b read as big-endian 16-bit words. An odd trailing
// byte is summed as if it were followed by a zero byte.
func ipChecksum(b []byte) uint16 {
	var sum uint32
	for ; len(b) >= 2; b = b[2:] {
		sum += uint32(binary.BigEndian.Uint16(b))
	}
	if len(b) == 1 {
		sum += uint32(b[0]) << 8
	}
	// Fold carries above bit 15 back into the low 16 bits.
	for sum > 0xffff {
		sum = sum>>16 + sum&0xffff
	}
	return ^uint16(sum)
}
// Shorthands for storing big-endian (network byte order) integers into
// packet buffers.
var (
	put16 = binary.BigEndian.PutUint16
	put32 = binary.BigEndian.PutUint32
)
// GenICMP returns the bytes of an IPv4 ICMP packet from srcIP to dstIP
// with the given IPv4 ID, ICMP type and ICMP code.
//
// payload is the ICMP message following the 4-byte type/code/checksum
// prefix (for echo messages: identifier, sequence number and data). It
// must be at least 4 bytes long. If payload is too short, or too long to
// fit in a 65535-byte IPv4 packet, GenICMP returns nil.
//
// The packet has a 20-byte IPv4 header without options, TTL 64, no
// fragmentation flags, and both the IPv4 and ICMP checksums filled in.
func GenICMP(srcIP, dstIP IP, ipid uint16, icmpType, icmpCode uint8, payload []byte) []byte {
	const hdrLen = 20 + 4 // IPv4 header + ICMP type, code and checksum
	if len(payload) < 4 || len(payload) > 65535-hdrLen {
		return nil
	}
	sz := hdrLen + len(payload)
	out := make([]byte, sz)
	out[0] = 0x45 // IPv4, 20-byte header
	out[1] = 0x00 // DSCP, ECN
	put16(out[2:4], uint16(sz))
	put16(out[4:6], ipid)
	put16(out[6:8], 0) // flags, offset
	out[8] = 64        // TTL
	out[9] = 0x01      // ICMPv4
	// out[10:12] is the IP header checksum; left zero until computed below.
	put32(out[12:16], uint32(srcIP))
	put32(out[16:20], uint32(dstIP))
	out[20] = icmpType
	out[21] = icmpCode
	// out[22:24] is the ICMP checksum; left zero until computed below.
	copy(out[24:], payload)
	put16(out[10:12], ipChecksum(out[:20]))
	// RFC 792: the ICMP checksum covers the ICMP message only, starting at
	// the type byte. Summing the IP header too only happens to agree when
	// the ICMP bytes are not all zero.
	put16(out[22:24], ipChecksum(out[20:]))
	return out
}
// Decode parses b as an IPv4 packet and stores the result in q.
//
// It is an extremely simple decoder for basic IPv4 packet types: it
// extracts only the subprotocol, the IP addresses and, for TCP and UDP,
// the ports (plus the flags for TCP). It shouldn't need any memory
// allocation.
//
// All of q's fields are reset first, so nothing from a previously decoded
// packet carries over. Packets that are truncated, not IPv4, malformed,
// of an unsupported subprotocol, or that look like a fragmentation-based
// attempt to hide the transport header are classified as Junk. Non-first
// fragments that pass the checks are classified as Fragment, with
// addresses but no ports. For ICMP, TCP and UDP packets q keeps a
// reference to b (not a copy), which Sub, Trim and the ICMP helpers read.
func (q *QDecode) Decode(b []byte) {
	// Junk is the zero IPProto, so this marks q as Junk until a branch
	// below classifies it, and clears the ports, TCP flags and buffer left
	// over from any previous packet.
	*q = QDecode{}

	if len(b) < 20 {
		q.IPProto = Junk
		return
	}
	// Check that it's IPv4.
	// TODO(apenwarr): consider IPv6 support
	if b[0]>>4 != 4 {
		q.IPProto = Junk
		return
	}
	n := int(binary.BigEndian.Uint16(b[2:4]))
	if len(b) < n {
		// Packet was cut off before full IPv4 length.
		q.IPProto = Junk
		return
	}

	// If it's valid IPv4, then the IP addresses are valid
	q.SrcIP = IP(binary.BigEndian.Uint32(b[12:16]))
	q.DstIP = IP(binary.BigEndian.Uint32(b[16:20]))

	// IHL counts 32-bit words. A header shorter than the fixed 20 bytes or
	// longer than the packet's total length is malformed; using it
	// unchecked would misparse the packet or panic when slicing.
	subofs := int(b[0]&0x0F) * 4
	if subofs < 20 || subofs > n {
		q.IPProto = Junk
		return
	}
	q.subofs = subofs
	// Bound the subprotocol view by the IPv4 total length so that extra
	// bytes some interfaces append (see Trim) aren't taken as header data.
	sub := b[subofs:n]

	// We don't care much about IP fragmentation, except insofar as it's
	// used for firewall bypass attacks. The trick is make the first
	// fragment of a TCP or UDP packet so short that it doesn't fit
	// the TCP or UDP header, so we can't read the port, in hope that
	// it'll sneak past. Then subsequent fragments fill it in, but we're
	// missing the first part of the header, so we can't read that either.
	//
	// A "perfectly correct" implementation would have to reassemble
	// fragments before deciding what to do. But the truth is there's
	// zero reason to send such a short first fragment, so we can treat
	// it as Junk. We can also treat any subsequent fragment that starts
	// at such a low offset as Junk.
	fragFlags := binary.BigEndian.Uint16(b[6:8])
	// MF ("more fragments") is bit 13 of the flags/offset word.
	moreFrags := fragFlags&0x2000 != 0
	// The offset field counts 8-byte units; convert it to bytes so it can
	// be compared with minFrag, which is in bytes.
	fragOfs := int(fragFlags&0x1FFF) * 8

	if fragOfs != 0 {
		// This is a fragment other than the first one.
		if fragOfs < minFrag {
			// First frag was suspiciously short, so we can't
			// trust the followup either.
			q.IPProto = Junk
			return
		}
		// Otherwise, we have to permit the fragment to slide through.
		// Second and later fragments don't have sub-headers.
		// Ideally, we would drop fragments that we can't identify,
		// but that would require statefulness. Anyway, receivers'
		// kernels know to drop fragments where the initial fragment
		// doesn't arrive.
		q.IPProto = Fragment
		return
	}

	// This is the first fragment, or the packet is not fragmented.
	if moreFrags && len(sub) < minFrag {
		// Suspiciously short first fragment, dump it.
		log.Printf("junk1!\n")
		q.IPProto = Junk
		return
	}
	// Otherwise, this is either non-fragmented (the usual case) or a big
	// enough initial fragment that we can read the whole subprotocol header.
	switch b[9] {
	case 1: // ICMPv4
		if len(sub) < 8 {
			q.IPProto = Junk
			return
		}
		q.IPProto = ICMP
	case 6: // TCP
		if len(sub) < 20 {
			q.IPProto = Junk
			return
		}
		q.IPProto = TCP
		q.SrcPort = binary.BigEndian.Uint16(sub[0:2])
		q.DstPort = binary.BigEndian.Uint16(sub[2:4])
		q.TCPFlags = sub[13] & 0x3F
	case 17: // UDP
		if len(sub) < 8 {
			q.IPProto = Junk
			return
		}
		q.IPProto = UDP
		q.SrcPort = binary.BigEndian.Uint16(sub[0:2])
		q.DstPort = binary.BigEndian.Uint16(sub[2:4])
	default:
		q.IPProto = Junk
		return
	}
	q.b = b
}
// Sub returns n bytes of the decoded packet's IP subprotocol section,
// starting begin bytes after the start of the subprotocol header. The
// result aliases the packet buffer. It panics if the range falls outside
// that buffer.
func (q *QDecode) Sub(begin, n int) []byte {
	start := q.subofs + begin
	end := start + n
	return q.b[start:end]
}
// Trim trims the buffer to its IPv4 length.
// Sometimes packets arrive from an interface with extra bytes on the end.
// This removes them. The result aliases the decoded packet buffer.
//
// Decode only retains the buffer for ICMP, TCP and UDP packets; when q
// holds no buffer (e.g. a Junk or Fragment packet) Trim returns nil
// instead of panicking.
func (q *QDecode) Trim() []byte {
	if len(q.b) < 4 {
		return nil
	}
	n := binary.BigEndian.Uint16(q.b[2:4])
	return q.b[:n]
}
// IsTCPSyn reports whether q is a TCP SYN packet (i.e. the
// first packet in a new connection): SYN set and ACK clear.
// Packets that are not TCP never qualify, whatever TCPFlags holds.
func (q *QDecode) IsTCPSyn() bool {
	return q.IPProto == TCP && q.TCPFlags&TCPSynAck == TCPSyn
}
// IsError reports whether q is an IPv4 ICMP "Error" packet. Only
// Destination Unreachable and Time Exceeded messages are recognized.
func (q *QDecode) IsError() bool {
	if q.IPProto != ICMP || len(q.b) < q.subofs+8 {
		return false
	}
	icmpType := q.b[q.subofs]
	return icmpType == ICMPUnreachable || icmpType == ICMPTimeExceeded
}
// IsEchoRequest reports whether q is an IPv4 ICMP Echo Request, i.e. an
// ICMP message with type ICMPEchoRequest and code 0.
func (q *QDecode) IsEchoRequest() bool {
	if q.IPProto != ICMP || len(q.b) < q.subofs+8 {
		return false
	}
	msg := q.b[q.subofs:]
	return msg[0] == ICMPEchoRequest && msg[1] == 0
}
// IsEchoResponse reports whether q is an IPv4 ICMP Echo Reply, i.e. an
// ICMP message with type ICMPEchoReply and code 0.
func (q *QDecode) IsEchoResponse() bool {
	if q.IPProto != ICMP || len(q.b) < q.subofs+8 {
		return false
	}
	return q.b[q.subofs] == ICMPEchoReply && q.b[q.subofs+1] == 0
}
// EchoRespond returns an IPv4 ICMP echo reply to the request in q,
// addressed from q's destination back to q's source.
//
// The reply carries everything after the request's 4-byte ICMP
// type/code/checksum prefix (identifier, sequence number and data, up to
// the IPv4 total length). It reuses the request's ICMP identifier as its
// IPv4 ID. The caller is expected to have checked q.IsEchoRequest(); this
// method does not, and it panics if q holds no packet buffer.
func (q *QDecode) EchoRespond() []byte {
	icmpid := binary.BigEndian.Uint16(q.Sub(4, 2))
	b := q.Trim()
	return GenICMP(q.DstIP, q.SrcIP, icmpid, ICMPEchoReply, 0, b[q.subofs+4:])
}
// Hexdump returns a human-readable hex dump of b, 16 bytes per line.
// Each line shows the offset, the bytes in hex (with an extra space after
// the eighth), and their printable-ASCII rendering with other bytes shown
// as '.'. Lines are joined with "\n" and there is no trailing newline; an
// empty b yields "".
func Hexdump(b []byte) string {
	out := new(strings.Builder)
	for i := 0; i < len(b); i += 16 {
		if i > 0 {
			out.WriteString("\n")
		}
		fmt.Fprintf(out, " %04x ", i)
		j := 0
		for ; j < 16 && i+j < len(b); j++ {
			if j == 8 {
				out.WriteString(" ")
			}
			fmt.Fprintf(out, "%02x ", b[i+j])
		}
		// Pad a short final line by the width of each missing "%02x "
		// entry (3 characters) so its ASCII column lines up with full lines.
		for ; j < 16; j++ {
			if j == 8 {
				out.WriteString(" ")
			}
			out.WriteString("   ")
		}
		out.WriteString(" ")
		for j = 0; j < 16 && i+j < len(b); j++ {
			c := b[i+j]
			// Printable ASCII is 0x20..0x7e; 0x7f (DEL) is a control character.
			if c >= 0x20 && c < 0x7f {
				out.WriteByte(c)
			} else {
				out.WriteByte('.')
			}
		}
	}
	return out.String()
}