diff --git a/wgengine/filter/filter.go b/wgengine/filter/filter.go index 2388f0ac8..eac00e56f 100644 --- a/wgengine/filter/filter.go +++ b/wgengine/filter/filter.go @@ -77,8 +77,8 @@ func (r Response) String() string { ) type tuple struct { - SrcIP packet.IP - DstIP packet.IP + SrcIP packet.IP4 + DstIP packet.IP4 SrcPort uint16 DstPort uint16 } @@ -412,7 +412,7 @@ func omitDropLogging(p *packet.ParsedPacket, dir direction) bool { // it doesn't know about, so parse it out ourselves if needed. ipProto := p.IPProto if ipProto == 0 && len(b) > 8 { - ipProto = packet.IPProto(b[9]) + ipProto = packet.IP4Proto(b[9]) } // Omit logging about outgoing IGMP. if ipProto == packet.IGMP { diff --git a/wgengine/filter/filter_test.go b/wgengine/filter/filter_test.go index b15b2729c..303ab9747 100644 --- a/wgengine/filter/filter_test.go +++ b/wgengine/filter/filter_test.go @@ -16,17 +16,13 @@ "tailscale.com/wgengine/packet" ) -// Type aliases only in test code: (but ideally nowhere) -type ParsedPacket = packet.ParsedPacket -type IP = packet.IP - var Unknown = packet.Unknown var ICMP = packet.ICMP var TCP = packet.TCP var UDP = packet.UDP var Fragment = packet.Fragment -func nets(ips []IP) []Net { +func nets(ips []packet.IP4) []Net { out := make([]Net, 0, len(ips)) for _, ip := range ips { out = append(out, Net{ip, Netmask(32)}) @@ -34,35 +30,35 @@ func nets(ips []IP) []Net { return out } -func ippr(ip IP, start, end uint16) []NetPortRange { +func ippr(ip packet.IP4, start, end uint16) []NetPortRange { return []NetPortRange{ NetPortRange{Net{ip, Netmask(32)}, PortRange{start, end}}, } } -func netpr(ip IP, bits int, start, end uint16) []NetPortRange { +func netpr(ip packet.IP4, bits int, start, end uint16) []NetPortRange { return []NetPortRange{ NetPortRange{Net{ip, Netmask(bits)}, PortRange{start, end}}, } } var matches = Matches{ - {Srcs: nets([]IP{0x08010101, 0x08020202}), Dsts: []NetPortRange{ + {Srcs: nets([]packet.IP4{0x08010101, 0x08020202}), Dsts: []NetPortRange{ NetPortRange{Net{0x01020304, Netmask(32)}, PortRange{22, 22}}, NetPortRange{Net{0x05060708, Netmask(32)}, PortRange{23, 24}}, }}, - {Srcs: nets([]IP{0x08010101, 0x08020202}), Dsts: ippr(0x05060708, 27, 28)}, - {Srcs: nets([]IP{0x02020202}), Dsts: ippr(0x08010101, 22, 22)}, + {Srcs: nets([]packet.IP4{0x08010101, 0x08020202}), Dsts: ippr(0x05060708, 27, 28)}, + {Srcs: nets([]packet.IP4{0x02020202}), Dsts: ippr(0x08010101, 22, 22)}, {Srcs: []Net{NetAny}, Dsts: ippr(0x647a6232, 0, 65535)}, {Srcs: []Net{NetAny}, Dsts: netpr(0, 0, 443, 443)}, - {Srcs: nets([]IP{0x99010101, 0x99010102, 0x99030303}), Dsts: ippr(0x01020304, 999, 999)}, + {Srcs: nets([]packet.IP4{0x99010101, 0x99010102, 0x99030303}), Dsts: ippr(0x01020304, 999, 999)}, } func newFilter(logf logger.Logf) *Filter { // Expects traffic to 100.122.98.50, 1.2.3.4, 5.6.7.8, // 102.102.102.102, 119.119.119.119, 8.1.0.0/16 - localNets := nets([]IP{0x647a6232, 0x01020304, 0x05060708, 0x66666666, 0x77777777}) - localNets = append(localNets, Net{IP(0x08010000), Netmask(16)}) + localNets := nets([]packet.IP4{0x647a6232, 0x01020304, 0x05060708, 0x66666666, 0x77777777}) + localNets = append(localNets, Net{packet.IP4(0x08010000), Netmask(16)}) return New(matches, localNets, nil, logf) } @@ -87,7 +83,7 @@ func TestFilter(t *testing.T) { type InOut struct { want Response - p ParsedPacket + p packet.ParsedPacket } tests := []InOut{ // Basic @@ -147,7 +143,7 @@ func TestNoAllocs(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { got := int(testing.AllocsPerRun(1000, func() { - q := &ParsedPacket{} + q := &packet.ParsedPacket{} q.Decode(test.packet) if test.in { acl.RunIn(q, 0) @@ -170,7 +166,7 @@ func TestParseIP(t *testing.T) { want Net wantErr string }{ - {"8.8.8.8", 24, Net{IP: packet.NewIP(net.ParseIP("8.8.8.8")), Mask: packet.NewIP(net.ParseIP("255.255.255.0"))}, ""}, + {"8.8.8.8", 24, Net{IP: packet.NewIP4(net.ParseIP("8.8.8.8")), Mask: packet.NewIP4(net.ParseIP("255.255.255.0"))}, ""}, {"8.8.8.8", 33, Net{}, `invalid CIDR size 33 for host "8.8.8.8"`}, {"8.8.8.8", -1, Net{}, `invalid CIDR size -1 for host "8.8.8.8"`}, {"0.0.0.0", 24, Net{}, `ports="0.0.0.0": to allow all IP addresses, use *:port, not 0.0.0.0:port`}, @@ -220,7 +216,7 @@ func BenchmarkFilter(b *testing.B) { for _, bench := range benches { b.Run(bench.name, func(b *testing.B) { for i := 0; i < b.N; i++ { - q := &ParsedPacket{} + q := &packet.ParsedPacket{} q.Decode(bench.packet) // This branch seems to have no measurable impact on performance. if bench.in { @@ -249,7 +245,7 @@ func TestPreFilter(t *testing.T) { } f := NewAllowNone(t.Logf) for _, testPacket := range packets { - p := &ParsedPacket{} + p := &packet.ParsedPacket{} p.Decode(testPacket.b) got := f.pre(p, LogDrops|LogAccepts, in) if got != testPacket.want { @@ -258,8 +254,8 @@ func TestPreFilter(t *testing.T) { } } -func parsed(proto packet.IPProto, src, dst packet.IP, sport, dport uint16) ParsedPacket { - return ParsedPacket{ +func parsed(proto packet.IP4Proto, src, dst packet.IP4, sport, dport uint16) packet.ParsedPacket { + return packet.ParsedPacket{ IPProto: proto, SrcIP: src, DstIP: dst, @@ -271,7 +267,7 @@ func parsed(proto packet.IPProto, src, dst packet.IP, sport, dport uint16) Parse // rawpacket generates a packet with given source and destination ports and IPs // and resizes the header to trimLength if it is nonzero. -func rawpacket(proto packet.IPProto, src, dst packet.IP, sport, dport uint16, trimLength int) []byte { +func rawpacket(proto packet.IP4Proto, src, dst packet.IP4, sport, dport uint16, trimLength int) []byte { var headerLength int switch proto { @@ -325,8 +321,8 @@ func rawpacket(proto packet.IPProto, src, dst packet.IP, sport, dport uint16, tr } // rawdefault calls rawpacket with default ports and IPs. -func rawdefault(proto packet.IPProto, trimLength int) []byte { - ip := IP(0x08080808) // 8.8.8.8 +func rawdefault(proto packet.IP4Proto, trimLength int) []byte { + ip := packet.IP4(0x08080808) // 8.8.8.8 port := uint16(53) return rawpacket(proto, ip, ip, port, port, trimLength) } @@ -381,19 +377,19 @@ func TestOmitDropLogging(t *testing.T) { }, { name: "v4_multicast_out_low", - pkt: &packet.ParsedPacket{IPVersion: 4, DstIP: packet.NewIP(net.ParseIP("224.0.0.0"))}, + pkt: &packet.ParsedPacket{IPVersion: 4, DstIP: packet.NewIP4(net.ParseIP("224.0.0.0"))}, dir: out, want: true, }, { name: "v4_multicast_out_high", - pkt: &packet.ParsedPacket{IPVersion: 4, DstIP: packet.NewIP(net.ParseIP("239.255.255.255"))}, + pkt: &packet.ParsedPacket{IPVersion: 4, DstIP: packet.NewIP4(net.ParseIP("239.255.255.255"))}, dir: out, want: true, }, { name: "v4_link_local_unicast", - pkt: &packet.ParsedPacket{IPVersion: 4, DstIP: packet.NewIP(net.ParseIP("169.254.1.2"))}, + pkt: &packet.ParsedPacket{IPVersion: 4, DstIP: packet.NewIP4(net.ParseIP("169.254.1.2"))}, dir: out, want: true, }, diff --git a/wgengine/filter/match.go b/wgengine/filter/match.go index 2632405c8..e2c170911 100644 --- a/wgengine/filter/match.go +++ b/wgengine/filter/match.go @@ -13,16 +13,16 @@ "tailscale.com/wgengine/packet" ) -func NewIP(ip net.IP) packet.IP { - return packet.NewIP(ip) +func NewIP(ip net.IP) packet.IP4 { + return packet.NewIP4(ip) } type Net struct { - IP packet.IP - Mask packet.IP + IP packet.IP4 + Mask packet.IP4 } -func (n Net) Includes(ip packet.IP) bool { +func (n Net) Includes(ip packet.IP4) bool { return (n.IP & n.Mask) == (ip & n.Mask) } @@ -42,11 +42,11 @@ func (n Net) String() string { } var NetAny = Net{0, 0} -var NetNone = Net{^packet.IP(0), ^packet.IP(0)} +var NetNone = Net{^packet.IP4(0), ^packet.IP4(0)} -func Netmask(bits int) packet.IP { +func Netmask(bits int) packet.IP4 { b := ^uint32((1 << (32 - bits)) - 1) - return packet.IP(b) + return packet.IP4(b) } type PortRange struct { @@ -124,7 +124,7 @@ func (m Matches) Clone() (res Matches) { return res } -func ipInList(ip packet.IP, netlist []Net) bool { +func ipInList(ip packet.IP4, netlist []Net) bool { for _, net := range netlist { if net.Includes(ip) { return true diff --git a/wgengine/packet/doc.go b/wgengine/packet/doc.go new file mode 100644 index 000000000..40dc87a14 --- /dev/null +++ b/wgengine/packet/doc.go @@ -0,0 +1,16 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package packet contains packet parsing and marshaling utilities. +// +// ParsedPacket provides allocation-free minimal packet header +// decoding, for use in packet filtering. The other types in the +// package are for constructing and marshaling packets into []bytes. +// +// To support allocation-free parsing, this package defines IPv4 and +// IPv6 address types. You should prefer to use netaddr's types, +// except where you absolutely need allocation-free IP handling +// (i.e. in the tunnel datapath) and are willing to implement all +// codepaths and data structures twice, once per IP family. +package packet diff --git a/wgengine/packet/icmp.go b/wgengine/packet/icmp4.go similarity index 55% rename from wgengine/packet/icmp.go rename to wgengine/packet/icmp4.go index c4cb7b149..c9e428793 100644 --- a/wgengine/packet/icmp.go +++ b/wgengine/packet/icmp4.go @@ -4,41 +4,41 @@ package packet -type ICMPType uint8 +type ICMP4Type uint8 const ( - ICMPEchoReply ICMPType = 0x00 - ICMPEchoRequest ICMPType = 0x08 - ICMPUnreachable ICMPType = 0x03 - ICMPTimeExceeded ICMPType = 0x0b + ICMP4EchoReply ICMP4Type = 0x00 + ICMP4EchoRequest ICMP4Type = 0x08 + ICMP4Unreachable ICMP4Type = 0x03 + ICMP4TimeExceeded ICMP4Type = 0x0b ) -func (t ICMPType) String() string { +func (t ICMP4Type) String() string { switch t { - case ICMPEchoReply: + case ICMP4EchoReply: return "EchoReply" - case ICMPEchoRequest: + case ICMP4EchoRequest: return "EchoRequest" - case ICMPUnreachable: + case ICMP4Unreachable: return "Unreachable" - case ICMPTimeExceeded: + case ICMP4TimeExceeded: return "TimeExceeded" default: return "Unknown" } } -type ICMPCode uint8 +type ICMP4Code uint8 const ( - ICMPNoCode ICMPCode = 0 + ICMP4NoCode ICMP4Code = 0 ) // ICMPHeader represents an ICMP packet header. -type ICMPHeader struct { - IPHeader - Type ICMPType - Code ICMPCode +type ICMP4Header struct { + IP4Header + Type ICMP4Type + Code ICMP4Code } const ( @@ -47,11 +47,11 @@ type ICMPHeader struct { icmpAllHeadersLength = ipHeaderLength + icmpHeaderLength ) -func (ICMPHeader) Len() int { +func (ICMP4Header) Len() int { return icmpAllHeadersLength } -func (h ICMPHeader) Marshal(buf []byte) error { +func (h ICMP4Header) Marshal(buf []byte) error { if len(buf) < icmpAllHeadersLength { return errSmallBuffer } @@ -64,15 +64,17 @@ func (h ICMPHeader) Marshal(buf []byte) error { buf[20] = uint8(h.Type) buf[21] = uint8(h.Code) - h.IPHeader.Marshal(buf) + h.IP4Header.Marshal(buf) put16(buf[22:24], ipChecksum(buf)) return nil } -func (h *ICMPHeader) ToResponse() { - h.Type = ICMPEchoReply - h.Code = ICMPNoCode - h.IPHeader.ToResponse() +func (h *ICMP4Header) ToResponse() { + // TODO: this doesn't implement ToResponse correctly, as it + // assumes the ICMP request type. + h.Type = ICMP4EchoReply + h.Code = ICMP4NoCode + h.IP4Header.ToResponse() } diff --git a/wgengine/packet/ip.go b/wgengine/packet/ip4.go similarity index 66% rename from wgengine/packet/ip.go rename to wgengine/packet/ip4.go index f26ce2626..09ebc3205 100644 --- a/wgengine/packet/ip.go +++ b/wgengine/packet/ip4.go @@ -11,61 +11,62 @@ "inet.af/netaddr" ) -// IP is an IPv4 address. -type IP uint32 +// IP4 is an IPv4 address. +type IP4 uint32 // NewIP converts a standard library IP address into an IP. // It panics if b is not an IPv4 address. -func NewIP(b net.IP) IP { +func NewIP4(b net.IP) IP4 { b4 := b.To4() if b4 == nil { panic(fmt.Sprintf("To4(%v) failed", b)) } - return IP(get32(b4)) + return IP4(get32(b4)) } // IPFromNetaddr converts a netaddr.IP to an IP. -func IPFromNetaddr(ip netaddr.IP) IP { +func IP4FromNetaddr(ip netaddr.IP) IP4 { ipbytes := ip.As4() - return IP(get32(ipbytes[:])) + return IP4(get32(ipbytes[:])) } // Netaddr converts an IP to a netaddr.IP. -func (ip IP) Netaddr() netaddr.IP { +func (ip IP4) Netaddr() netaddr.IP { return netaddr.IPv4(byte(ip>>24), byte(ip>>16), byte(ip>>8), byte(ip)) } -func (ip IP) String() string { +func (ip IP4) String() string { return fmt.Sprintf("%d.%d.%d.%d", byte(ip>>24), byte(ip>>16), byte(ip>>8), byte(ip)) } -func (ip IP) IsMulticast() bool { +func (ip IP4) IsMulticast() bool { return byte(ip>>24)&0xf0 == 0xe0 } -func (ip IP) IsLinkLocalUnicast() bool { +func (ip IP4) IsLinkLocalUnicast() bool { return byte(ip>>24) == 169 && byte(ip>>16) == 254 } -// IPProto is either a real IP protocol (ITCP, UDP, ...) or an special value like Unknown. -// If it is a real IP protocol, its value corresponds to its IP protocol number. -type IPProto uint8 +// IP4Proto is either a real IP protocol (TCP, UDP, ...) or an special +// value like Unknown. If it is a real IP protocol, its value +// corresponds to its IP protocol number. +type IP4Proto uint8 const ( // Unknown represents an unknown or unsupported protocol; it's deliberately the zero value. - Unknown IPProto = 0x00 - ICMP IPProto = 0x01 - IGMP IPProto = 0x02 - ICMPv6 IPProto = 0x3a - TCP IPProto = 0x06 - UDP IPProto = 0x11 + Unknown IP4Proto = 0x00 + ICMP IP4Proto = 0x01 + IGMP IP4Proto = 0x02 + ICMPv6 IP4Proto = 0x3a + TCP IP4Proto = 0x06 + UDP IP4Proto = 0x11 // Fragment is a special value. It's not really an IPProto value // so we're using the unassigned 0xFF value. // TODO(dmytro): special values should be taken out of here. - Fragment IPProto = 0xFF + Fragment IP4Proto = 0xFF ) -func (p IPProto) String() string { +func (p IP4Proto) String() string { switch p { case Fragment: return "Frag" @@ -81,20 +82,20 @@ func (p IPProto) String() string { } // IPHeader represents an IP packet header. -type IPHeader struct { - IPProto IPProto +type IP4Header struct { + IPProto IP4Proto IPID uint16 - SrcIP IP - DstIP IP + SrcIP IP4 + DstIP IP4 } const ipHeaderLength = 20 -func (IPHeader) Len() int { +func (IP4Header) Len() int { return ipHeaderLength } -func (h IPHeader) Marshal(buf []byte) error { +func (h IP4Header) Marshal(buf []byte) error { if len(buf) < ipHeaderLength { return errSmallBuffer } @@ -118,11 +119,10 @@ func (h IPHeader) Marshal(buf []byte) error { return nil } -// MarshalPseudo serializes the header into buf in pseudo format. -// It clobbers the header region, which is the first h.Length() bytes of buf. -// It explicitly initializes every byte of the header region, -// so pre-zeroing it on reuse is not required. It does not allocate memory. -func (h IPHeader) MarshalPseudo(buf []byte) error { +// MarshalPseudo serializes the header into buf in the "pseudo-header" +// form required when calculating UDP checksums. Overwrites the first +// h.Length() bytes of buf. +func (h IP4Header) MarshalPseudo(buf []byte) error { if len(buf) < ipHeaderLength { return errSmallBuffer } @@ -140,7 +140,8 @@ func (h IPHeader) MarshalPseudo(buf []byte) error { return nil } -func (h *IPHeader) ToResponse() { +// ToResponse implements Header. +func (h *IP4Header) ToResponse() { h.SrcIP, h.DstIP = h.DstIP, h.SrcIP // Flip the bits in the IPID. If incoming IPIDs are distinct, so are these. h.IPID = ^h.IPID diff --git a/wgengine/packet/packet.go b/wgengine/packet/packet.go index a82e84ab1..fe206c964 100644 --- a/wgengine/packet/packet.go +++ b/wgengine/packet/packet.go @@ -43,13 +43,13 @@ type ParsedPacket struct { // This is not the same as len(b) because b can have trailing zeros. length int - IPVersion uint8 // 4, 6, or 0 - IPProto IPProto // IP subprotocol (UDP, TCP, etc); the NextHeader field for IPv6 - SrcIP IP // IP source address (not used for IPv6) - DstIP IP // IP destination address (not used for IPv6) - SrcPort uint16 // TCP/UDP source port - DstPort uint16 // TCP/UDP destination port - TCPFlags uint8 // TCP flags (SYN, ACK, etc) + IPVersion uint8 // 4, 6, or 0 + IPProto IP4Proto // IP subprotocol (UDP, TCP, etc); the NextHeader field for IPv6 + SrcIP IP4 // IP source address (not used for IPv6) + DstIP IP4 // IP destination address (not used for IPv6) + SrcPort uint16 // TCP/UDP source port + DstPort uint16 // TCP/UDP destination port + TCPFlags uint8 // TCP flags (SYN, ACK, etc) } // NextHeader @@ -73,7 +73,7 @@ func (p *ParsedPacket) String() string { return sb.String() } -func writeIPPort(sb *strbuilder.Builder, ip IP, port uint16) { +func writeIPPort(sb *strbuilder.Builder, ip IP4, port uint16) { sb.WriteUint(uint64(byte(ip >> 24))) sb.WriteByte('.') sb.WriteUint(uint64(byte(ip >> 16))) @@ -122,9 +122,9 @@ func (q *ParsedPacket) Decode(b []byte) { q.IPVersion = (b[0] & 0xF0) >> 4 switch q.IPVersion { case 4: - q.IPProto = IPProto(b[9]) + q.IPProto = IP4Proto(b[9]) case 6: - q.IPProto = IPProto(b[6]) // "Next Header" field + q.IPProto = IP4Proto(b[6]) // "Next Header" field return default: q.IPVersion = 0 @@ -140,8 +140,8 @@ func (q *ParsedPacket) Decode(b []byte) { } // If it's valid IPv4, then the IP addresses are valid - q.SrcIP = IP(get32(b[12:16])) - q.DstIP = IP(get32(b[16:20])) + q.SrcIP = IP4(get32(b[12:16])) + q.DstIP = IP4(get32(b[16:20])) q.subofs = int((b[0] & 0x0F) << 2) sub := b[q.subofs:] @@ -224,9 +224,9 @@ func (q *ParsedPacket) Decode(b []byte) { } } -func (q *ParsedPacket) IPHeader() IPHeader { +func (q *ParsedPacket) IPHeader() IP4Header { ipid := get16(q.b[4:6]) - return IPHeader{ + return IP4Header{ IPID: ipid, IPProto: q.IPProto, SrcIP: q.SrcIP, @@ -234,19 +234,19 @@ func (q *ParsedPacket) IPHeader() IPHeader { } } -func (q *ParsedPacket) ICMPHeader() ICMPHeader { - return ICMPHeader{ - IPHeader: q.IPHeader(), - Type: ICMPType(q.b[q.subofs+0]), - Code: ICMPCode(q.b[q.subofs+1]), +func (q *ParsedPacket) ICMPHeader() ICMP4Header { + return ICMP4Header{ + IP4Header: q.IPHeader(), + Type: ICMP4Type(q.b[q.subofs+0]), + Code: ICMP4Code(q.b[q.subofs+1]), } } -func (q *ParsedPacket) UDPHeader() UDPHeader { - return UDPHeader{ - IPHeader: q.IPHeader(), - SrcPort: q.SrcPort, - DstPort: q.DstPort, +func (q *ParsedPacket) UDPHeader() UDP4Header { + return UDP4Header{ + IP4Header: q.IPHeader(), + SrcPort: q.SrcPort, + DstPort: q.DstPort, } } @@ -284,8 +284,8 @@ func (q *ParsedPacket) IsTCPSyn() bool { // IsError reports whether q is an IPv4 ICMP "Error" packet. func (q *ParsedPacket) IsError() bool { if q.IPProto == ICMP && len(q.b) >= q.subofs+8 { - switch ICMPType(q.b[q.subofs]) { - case ICMPUnreachable, ICMPTimeExceeded: + switch ICMP4Type(q.b[q.subofs]) { + case ICMP4Unreachable, ICMP4TimeExceeded: return true } } @@ -295,8 +295,8 @@ func (q *ParsedPacket) IsError() bool { // IsEchoRequest reports whether q is an IPv4 ICMP Echo Request. func (q *ParsedPacket) IsEchoRequest() bool { if q.IPProto == ICMP && len(q.b) >= q.subofs+8 { - return ICMPType(q.b[q.subofs]) == ICMPEchoRequest && - ICMPCode(q.b[q.subofs+1]) == ICMPNoCode + return ICMP4Type(q.b[q.subofs]) == ICMP4EchoRequest && + ICMP4Code(q.b[q.subofs+1]) == ICMP4NoCode } return false } @@ -304,8 +304,8 @@ func (q *ParsedPacket) IsEchoRequest() bool { // IsEchoRequest reports whether q is an IPv4 ICMP Echo Response. func (q *ParsedPacket) IsEchoResponse() bool { if q.IPProto == ICMP && len(q.b) >= q.subofs+8 { - return ICMPType(q.b[q.subofs]) == ICMPEchoReply && - ICMPCode(q.b[q.subofs+1]) == ICMPNoCode + return ICMP4Type(q.b[q.subofs]) == ICMP4EchoReply && + ICMP4Code(q.b[q.subofs+1]) == ICMP4NoCode } return false } diff --git a/wgengine/packet/packet_test.go b/wgengine/packet/packet_test.go index 8b286da51..4ed31ca5e 100644 --- a/wgengine/packet/packet_test.go +++ b/wgengine/packet/packet_test.go @@ -11,9 +11,9 @@ "testing" ) -func TestIPString(t *testing.T) { +func TestIP4String(t *testing.T) { const str = "1.2.3.4" - ip := NewIP(net.ParseIP(str)) + ip := NewIP4(net.ParseIP(str)) var got string allocs := testing.AllocsPerRun(1000, func() { @@ -49,8 +49,8 @@ func TestIPString(t *testing.T) { IPVersion: 4, IPProto: ICMP, - SrcIP: NewIP(net.ParseIP("1.2.3.4")), - DstIP: NewIP(net.ParseIP("5.6.7.8")), + SrcIP: NewIP4(net.ParseIP("1.2.3.4")), + DstIP: NewIP4(net.ParseIP("5.6.7.8")), SrcPort: 0, DstPort: 0, } @@ -75,8 +75,8 @@ func TestIPString(t *testing.T) { IPVersion: 4, IPProto: ICMP, - SrcIP: NewIP(net.ParseIP("1.2.3.4")), - DstIP: NewIP(net.ParseIP("5.6.7.8")), + SrcIP: NewIP4(net.ParseIP("1.2.3.4")), + DstIP: NewIP4(net.ParseIP("5.6.7.8")), SrcPort: 0, DstPort: 0, } @@ -131,8 +131,8 @@ func TestIPString(t *testing.T) { IPVersion: 4, IPProto: TCP, - SrcIP: NewIP(net.ParseIP("1.2.3.4")), - DstIP: NewIP(net.ParseIP("5.6.7.8")), + SrcIP: NewIP4(net.ParseIP("1.2.3.4")), + DstIP: NewIP4(net.ParseIP("5.6.7.8")), SrcPort: 123, DstPort: 567, TCPFlags: TCPSynAck, @@ -159,8 +159,8 @@ func TestIPString(t *testing.T) { IPVersion: 4, IPProto: UDP, - SrcIP: NewIP(net.ParseIP("1.2.3.4")), - DstIP: NewIP(net.ParseIP("5.6.7.8")), + SrcIP: NewIP4(net.ParseIP("1.2.3.4")), + DstIP: NewIP4(net.ParseIP("5.6.7.8")), SrcPort: 123, DstPort: 567, } @@ -185,8 +185,8 @@ func TestIPString(t *testing.T) { length: len(udpReplyBuffer), IPProto: UDP, - SrcIP: NewIP(net.ParseIP("1.2.3.4")), - DstIP: NewIP(net.ParseIP("5.6.7.8")), + SrcIP: NewIP4(net.ParseIP("1.2.3.4")), + DstIP: NewIP4(net.ParseIP("5.6.7.8")), SrcPort: 567, DstPort: 123, } diff --git a/wgengine/packet/udp.go b/wgengine/packet/udp4.go similarity index 77% rename from wgengine/packet/udp.go rename to wgengine/packet/udp4.go index 76cc9c922..ecc09b7bb 100644 --- a/wgengine/packet/udp.go +++ b/wgengine/packet/udp4.go @@ -5,8 +5,8 @@ package packet // UDPHeader represents an UDP packet header. -type UDPHeader struct { - IPHeader +type UDP4Header struct { + IP4Header SrcPort uint16 DstPort uint16 } @@ -17,11 +17,11 @@ type UDPHeader struct { udpTotalHeaderLength = ipHeaderLength + udpHeaderLength ) -func (UDPHeader) Len() int { +func (UDP4Header) Len() int { return udpTotalHeaderLength } -func (h UDPHeader) Marshal(buf []byte) error { +func (h UDP4Header) Marshal(buf []byte) error { if len(buf) < udpTotalHeaderLength { return errSmallBuffer } @@ -31,23 +31,23 @@ func (h UDPHeader) Marshal(buf []byte) error { // The caller does not need to set this. h.IPProto = UDP - length := len(buf) - h.IPHeader.Len() + length := len(buf) - h.IP4Header.Len() put16(buf[20:22], h.SrcPort) put16(buf[22:24], h.DstPort) put16(buf[24:26], uint16(length)) put16(buf[26:28], 0) // blank checksum - h.IPHeader.MarshalPseudo(buf) + h.IP4Header.MarshalPseudo(buf) // UDP checksum with IP pseudo header. put16(buf[26:28], ipChecksum(buf[8:])) - h.IPHeader.Marshal(buf) + h.IP4Header.Marshal(buf) return nil } -func (h *UDPHeader) ToResponse() { +func (h *UDP4Header) ToResponse() { h.SrcPort, h.DstPort = h.DstPort, h.SrcPort - h.IPHeader.ToResponse() + h.IP4Header.ToResponse() } diff --git a/wgengine/tstun/tun.go b/wgengine/tstun/tun.go index 77ebc52c4..ebba40b75 100644 --- a/wgengine/tstun/tun.go +++ b/wgengine/tstun/tun.go @@ -136,7 +136,7 @@ func WrapTUN(logf logger.Logf, tdev tun.Device) *TUN { // destination (the map keys). // // The map ownership passes to the TUN. It must be non-nil. -func (t *TUN) SetDestIPActivityFuncs(m map[packet.IP]func()) { +func (t *TUN) SetDestIPActivityFuncs(m map[packet.IP4]func()) { t.destIPActivity.Store(m) } @@ -282,7 +282,7 @@ func (t *TUN) Read(buf []byte, offset int) (int, error) { defer parsedPacketPool.Put(p) p.Decode(buf[offset : offset+n]) - if m, ok := t.destIPActivity.Load().(map[packet.IP]func()); ok { + if m, ok := t.destIPActivity.Load().(map[packet.IP4]func()); ok { if fn := m[p.DstIP]; fn != nil { fn() } diff --git a/wgengine/tstun/tun_test.go b/wgengine/tstun/tun_test.go index 3394a1aba..3d48e6286 100644 --- a/wgengine/tstun/tun_test.go +++ b/wgengine/tstun/tun_test.go @@ -16,9 +16,9 @@ "tailscale.com/wgengine/packet" ) -func udp(src, dst packet.IP, sport, dport uint16) []byte { - header := &packet.UDPHeader{ - IPHeader: packet.IPHeader{ +func udp(src, dst packet.IP4, sport, dport uint16) []byte { + header := &packet.UDP4Header{ + IP4Header: packet.IP4Header{ SrcIP: src, DstIP: dst, IPID: 0, @@ -29,11 +29,11 @@ func udp(src, dst packet.IP, sport, dport uint16) []byte { return packet.Generate(header, []byte("udp_payload")) } -func filterNet(ip, mask packet.IP) filter.Net { +func filterNet(ip, mask packet.IP4) filter.Net { return filter.Net{IP: ip, Mask: mask} } -func nets(ips []packet.IP) []filter.Net { +func nets(ips []packet.IP4) []filter.Net { out := make([]filter.Net, 0, len(ips)) for _, ip := range ips { out = append(out, filterNet(ip, filter.Netmask(32))) @@ -41,7 +41,7 @@ func nets(ips []packet.IP) []filter.Net { return out } -func ippr(ip packet.IP, start, end uint16) []filter.NetPortRange { +func ippr(ip packet.IP4, start, end uint16) []filter.NetPortRange { return []filter.NetPortRange{ filter.NetPortRange{ Net: filterNet(ip, filter.Netmask(32)), @@ -52,11 +52,11 @@ func ippr(ip packet.IP, start, end uint16) []filter.NetPortRange { func setfilter(logf logger.Logf, tun *TUN) { matches := filter.Matches{ - {Srcs: nets([]packet.IP{0x05060708}), Dsts: ippr(0x01020304, 89, 90)}, - {Srcs: nets([]packet.IP{0x01020304}), Dsts: ippr(0x05060708, 98, 98)}, + {Srcs: nets([]packet.IP4{0x05060708}), Dsts: ippr(0x01020304, 89, 90)}, + {Srcs: nets([]packet.IP4{0x01020304}), Dsts: ippr(0x05060708, 98, 98)}, } localNets := []filter.Net{ - filterNet(packet.IP(0x01020304), filter.Netmask(16)), + filterNet(packet.IP4(0x01020304), filter.Netmask(16)), } tun.SetFilter(filter.New(matches, localNets, nil, logf)) } diff --git a/wgengine/userspace.go b/wgengine/userspace.go index 3542126ed..39eb9e641 100644 --- a/wgengine/userspace.go +++ b/wgengine/userspace.go @@ -99,7 +99,7 @@ type userspaceEngine struct { // localAddrs is the set of IP addresses assigned to the local // tunnel interface. It's used to reflect local packets // incorrectly sent to us. - localAddrs atomic.Value // of map[packet.IP]bool + localAddrs atomic.Value // of map[packet.IP4]bool wgLock sync.Mutex // serializes all wgdev operations; see lock order comment below lastCfgFull wgcfg.Config @@ -108,8 +108,8 @@ type userspaceEngine struct { lastEngineSigTrim string // of trimmed wireguard config recvActivityAt map[tailcfg.DiscoKey]time.Time trimmedDisco map[tailcfg.DiscoKey]bool // set of disco keys of peers currently excluded from wireguard config - sentActivityAt map[packet.IP]*int64 // value is atomic int64 of unixtime - destIPActivityFuncs map[packet.IP]func() + sentActivityAt map[packet.IP4]*int64 // value is atomic int64 of unixtime + destIPActivityFuncs map[packet.IP4]func() mu sync.Mutex // guards following; see lock order comment below closing bool // Close was called (even if we're still closing) @@ -206,7 +206,7 @@ func newUserspaceEngineAdvanced(conf EngineConfig) (_ Engine, reterr error) { resolver: tsdns.NewResolver(rconf), pingers: make(map[wgcfg.Key]*pinger), } - e.localAddrs.Store(map[packet.IP]bool{}) + e.localAddrs.Store(map[packet.IP4]bool{}) e.linkState, _ = getLinkState() logf("link state: %+v", e.linkState) @@ -410,8 +410,8 @@ func (e *userspaceEngine) handleLocalPackets(p *packet.ParsedPacket, t *tstun.TU return filter.Accept } -func (e *userspaceEngine) isLocalAddr(ip packet.IP) bool { - localAddrs, ok := e.localAddrs.Load().(map[packet.IP]bool) +func (e *userspaceEngine) isLocalAddr(ip packet.IP4) bool { + localAddrs, ok := e.localAddrs.Load().(map[packet.IP4]bool) if !ok { e.logf("[unexpected] e.localAddrs was nil, can't check for loopback packet") return false @@ -447,10 +447,10 @@ func (e *userspaceEngine) pollResolver() { continue } - h := packet.UDPHeader{ - IPHeader: packet.IPHeader{ - SrcIP: packet.IP(magicDNSIP), - DstIP: packet.IPFromNetaddr(resp.Addr.IP), + h := packet.UDP4Header{ + IP4Header: packet.IP4Header{ + SrcIP: packet.IP4(magicDNSIP), + DstIP: packet.IP4FromNetaddr(resp.Addr.IP), }, SrcPort: magicDNSPort, DstPort: resp.Addr.Port, @@ -487,7 +487,7 @@ func (p *pinger) close() { <-p.done } -func (p *pinger) run(ctx context.Context, peerKey wgcfg.Key, ips []wgcfg.IP, srcIP packet.IP) { +func (p *pinger) run(ctx context.Context, peerKey wgcfg.Key, ips []wgcfg.IP, srcIP packet.IP4) { defer func() { p.e.mu.Lock() if p.e.pingers[peerKey] == p { @@ -498,12 +498,12 @@ func (p *pinger) run(ctx context.Context, peerKey wgcfg.Key, ips []wgcfg.IP, src close(p.done) }() - header := packet.ICMPHeader{ - IPHeader: packet.IPHeader{ + header := packet.ICMP4Header{ + IP4Header: packet.IP4Header{ SrcIP: srcIP, }, - Type: packet.ICMPEchoRequest, - Code: packet.ICMPNoCode, + Type: packet.ICMP4EchoRequest, + Code: packet.ICMP4NoCode, } // sendFreq is slightly longer than sprayFreq in magicsock to ensure @@ -513,9 +513,9 @@ func (p *pinger) run(ctx context.Context, peerKey wgcfg.Key, ips []wgcfg.IP, src const stopAfter = 3 * time.Second start := time.Now() - var dstIPs []packet.IP + var dstIPs []packet.IP4 for _, ip := range ips { - dstIPs = append(dstIPs, packet.NewIP(ip.IP())) + dstIPs = append(dstIPs, packet.NewIP4(ip.IP())) } payload := []byte("magicsock_spray") // no meaning @@ -551,11 +551,11 @@ func (p *pinger) run(ctx context.Context, peerKey wgcfg.Key, ips []wgcfg.IP, src // have advertised discovery keys. func (e *userspaceEngine) pinger(peerKey wgcfg.Key, ips []wgcfg.IP) { e.logf("generating initial ping traffic to %s (%v)", peerKey.ShortString(), ips) - var srcIP packet.IP + var srcIP packet.IP4 e.wgLock.Lock() if len(e.lastCfgFull.Addresses) > 0 { - srcIP = packet.NewIP(e.lastCfgFull.Addresses[0].IP.IP()) + srcIP = packet.NewIP4(e.lastCfgFull.Addresses[0].IP.IP()) } e.wgLock.Unlock() @@ -681,7 +681,7 @@ func (e *userspaceEngine) isActiveSince(dk tailcfg.DiscoKey, ip wgcfg.IP, t time if e.recvActivityAt[dk].After(t) { return true } - pip := packet.IP(binary.BigEndian.Uint32(ip.Addr[12:])) + pip := packet.IP4(binary.BigEndian.Uint32(ip.Addr[12:])) timePtr, ok := e.sentActivityAt[pip] if !ok { return false @@ -792,12 +792,12 @@ func (e *userspaceEngine) updateActivityMapsLocked(trackDisco []tailcfg.DiscoKey e.recvActivityAt = mr oldTime := e.sentActivityAt - e.sentActivityAt = make(map[packet.IP]*int64, len(oldTime)) + e.sentActivityAt = make(map[packet.IP4]*int64, len(oldTime)) oldFunc := e.destIPActivityFuncs - e.destIPActivityFuncs = make(map[packet.IP]func(), len(oldFunc)) + e.destIPActivityFuncs = make(map[packet.IP4]func(), len(oldFunc)) for _, wip := range trackIPs { - pip := packet.IP(binary.BigEndian.Uint32(wip.Addr[12:])) + pip := packet.IP4(binary.BigEndian.Uint32(wip.Addr[12:])) timePtr := oldTime[pip] if timePtr == nil { timePtr = new(int64) @@ -837,13 +837,13 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config) panic("routerCfg must not be nil") } - localAddrs := map[packet.IP]bool{} + localAddrs := map[packet.IP4]bool{} for _, addr := range routerCfg.LocalAddrs { // TODO: ipv6 if !addr.IP.Is4() { continue } - localAddrs[packet.IPFromNetaddr(addr.IP)] = true + localAddrs[packet.IP4FromNetaddr(addr.IP)] = true } e.localAddrs.Store(localAddrs)