net/packet: remove the custom IP4/IP6 types in favor of netaddr.IP.

Upstream netaddr has a change that makes it alloc-free, so it's safe to
use in hot codepaths. This gets rid of one of the many IP types in our
codebase.

Performance is currently worse across the board. This is likely due in
part to netaddr.IP being a larger value type (4b -> 24b for IPv4,
16b -> 24b for IPv6), and in other part due to missing low-hanging fruit
optimizations in netaddr. However, the regression is less bad than
it looks at first glance, because we'd micro-optimized packet.IP* in
the past few weeks. This change drops us back to roughly where we
were at the 1.2 release, but with the benefit of a significant
code and architectural simplification.

name                   old time/op    new time/op    delta
pkg:tailscale.com/net/packet goos:linux goarch:amd64
Decode/tcp4-8            12.2ns ± 5%    29.7ns ± 2%  +142.32%  (p=0.008 n=5+5)
Decode/tcp6-8            12.6ns ± 3%    65.1ns ± 2%  +418.47%  (p=0.008 n=5+5)
Decode/udp4-8            11.8ns ± 3%    30.5ns ± 2%  +157.94%  (p=0.008 n=5+5)
Decode/udp6-8            27.1ns ± 1%    65.7ns ± 2%  +142.36%  (p=0.016 n=4+5)
Decode/icmp4-8           24.6ns ± 2%    30.5ns ± 2%   +23.65%  (p=0.016 n=4+5)
Decode/icmp6-8           22.9ns ±51%    65.5ns ± 2%  +186.19%  (p=0.008 n=5+5)
Decode/igmp-8            18.1ns ±44%    30.2ns ± 1%   +66.89%  (p=0.008 n=5+5)
Decode/unknown-8         20.8ns ± 1%    10.6ns ± 9%   -49.11%  (p=0.016 n=4+5)
pkg:tailscale.com/wgengine/filter goos:linux goarch:amd64
Filter/icmp4-8           30.5ns ± 1%    77.9ns ± 3%  +155.01%  (p=0.008 n=5+5)
Filter/tcp4_syn_in-8     43.7ns ± 3%   123.0ns ± 3%  +181.72%  (p=0.008 n=5+5)
Filter/tcp4_syn_out-8    24.5ns ± 2%    45.7ns ± 6%   +86.22%  (p=0.008 n=5+5)
Filter/udp4_in-8         64.8ns ± 1%   210.0ns ± 2%  +223.87%  (p=0.008 n=5+5)
Filter/udp4_out-8         119ns ± 0%     278ns ± 0%  +133.78%  (p=0.016 n=4+5)
Filter/icmp6-8           40.3ns ± 2%   204.4ns ± 4%  +407.70%  (p=0.008 n=5+5)
Filter/tcp6_syn_in-8     35.3ns ± 3%   199.2ns ± 2%  +464.95%  (p=0.008 n=5+5)
Filter/tcp6_syn_out-8    32.8ns ± 2%    81.0ns ± 2%  +147.10%  (p=0.008 n=5+5)
Filter/udp6_in-8          106ns ± 2%     290ns ± 2%  +174.48%  (p=0.008 n=5+5)
Filter/udp6_out-8         184ns ± 2%     314ns ± 3%   +70.43%  (p=0.016 n=4+5)
pkg:tailscale.com/wgengine/tstun goos:linux goarch:amd64
Write-8                  9.02ns ± 3%    8.92ns ± 1%      ~     (p=0.421 n=5+5)

name                   old alloc/op   new alloc/op   delta
pkg:tailscale.com/net/packet goos:linux goarch:amd64
Decode/tcp4-8             0.00B          0.00B           ~     (all equal)
Decode/tcp6-8             0.00B          0.00B           ~     (all equal)
Decode/udp4-8             0.00B          0.00B           ~     (all equal)
Decode/udp6-8             0.00B          0.00B           ~     (all equal)
Decode/icmp4-8            0.00B          0.00B           ~     (all equal)
Decode/icmp6-8            0.00B          0.00B           ~     (all equal)
Decode/igmp-8             0.00B          0.00B           ~     (all equal)
Decode/unknown-8          0.00B          0.00B           ~     (all equal)
pkg:tailscale.com/wgengine/filter goos:linux goarch:amd64
Filter/icmp4-8            0.00B          0.00B           ~     (all equal)
Filter/tcp4_syn_in-8      0.00B          0.00B           ~     (all equal)
Filter/tcp4_syn_out-8     0.00B          0.00B           ~     (all equal)
Filter/udp4_in-8          0.00B          0.00B           ~     (all equal)
Filter/udp4_out-8         16.0B ± 0%     64.0B ± 0%  +300.00%  (p=0.008 n=5+5)
Filter/icmp6-8            0.00B          0.00B           ~     (all equal)
Filter/tcp6_syn_in-8      0.00B          0.00B           ~     (all equal)
Filter/tcp6_syn_out-8     0.00B          0.00B           ~     (all equal)
Filter/udp6_in-8          0.00B          0.00B           ~     (all equal)
Filter/udp6_out-8         48.0B ± 0%     64.0B ± 0%   +33.33%  (p=0.008 n=5+5)

name                   old allocs/op  new allocs/op  delta
pkg:tailscale.com/net/packet goos:linux goarch:amd64
Decode/tcp4-8              0.00           0.00           ~     (all equal)
Decode/tcp6-8              0.00           0.00           ~     (all equal)
Decode/udp4-8              0.00           0.00           ~     (all equal)
Decode/udp6-8              0.00           0.00           ~     (all equal)
Decode/icmp4-8             0.00           0.00           ~     (all equal)
Decode/icmp6-8             0.00           0.00           ~     (all equal)
Decode/igmp-8              0.00           0.00           ~     (all equal)
Decode/unknown-8           0.00           0.00           ~     (all equal)
pkg:tailscale.com/wgengine/filter goos:linux goarch:amd64
Filter/icmp4-8             0.00           0.00           ~     (all equal)
Filter/tcp4_syn_in-8       0.00           0.00           ~     (all equal)
Filter/tcp4_syn_out-8      0.00           0.00           ~     (all equal)
Filter/udp4_in-8           0.00           0.00           ~     (all equal)
Filter/udp4_out-8          1.00 ± 0%      1.00 ± 0%      ~     (all equal)
Filter/icmp6-8             0.00           0.00           ~     (all equal)
Filter/tcp6_syn_in-8       0.00           0.00           ~     (all equal)
Filter/tcp6_syn_out-8      0.00           0.00           ~     (all equal)
Filter/udp6_in-8           0.00           0.00           ~     (all equal)
Filter/udp6_out-8          1.00 ± 0%      1.00 ± 0%      ~     (all equal)

Signed-off-by: David Anderson <danderson@tailscale.com>
This commit is contained in:
David Anderson 2020-12-19 16:43:25 -08:00 committed by Brad Fitzpatrick
parent d0baece5fa
commit cb96b14bf4
13 changed files with 323 additions and 827 deletions

View File

@ -6,47 +6,11 @@
import ( import (
"encoding/binary" "encoding/binary"
"fmt" "errors"
"inet.af/netaddr" "inet.af/netaddr"
) )
// IP4 is an IPv4 address.
type IP4 uint32
// IPFromNetaddr converts a netaddr.IP to an IP4. Panics if !ip.Is4.
func IP4FromNetaddr(ip netaddr.IP) IP4 {
ipbytes := ip.As4()
return IP4(binary.BigEndian.Uint32(ipbytes[:]))
}
// Netaddr converts ip to a netaddr.IP.
func (ip IP4) Netaddr() netaddr.IP {
return netaddr.IPv4(byte(ip>>24), byte(ip>>16), byte(ip>>8), byte(ip))
}
func (ip IP4) String() string {
return fmt.Sprintf("%d.%d.%d.%d", byte(ip>>24), byte(ip>>16), byte(ip>>8), byte(ip))
}
// IsMulticast returns whether ip is a multicast address.
func (ip IP4) IsMulticast() bool {
return byte(ip>>24)&0xf0 == 0xe0
}
// IsLinkLocalUnicast returns whether ip is a link-local unicast
// address.
func (ip IP4) IsLinkLocalUnicast() bool {
return byte(ip>>24) == 169 && byte(ip>>16) == 254
}
// IsMostLinkLocalUnicast returns whether ip is a link-local unicast
// address other than the magical "169.254.169.254" address used by
// GCP DNS.
func (ip IP4) IsMostLinkLocalUnicast() bool {
return ip.IsLinkLocalUnicast() && ip != 0xA9FEA9FE
}
// ip4HeaderLength is the length of an IPv4 header with no IP options. // ip4HeaderLength is the length of an IPv4 header with no IP options.
const ip4HeaderLength = 20 const ip4HeaderLength = 20
@ -54,8 +18,8 @@ func (ip IP4) IsMostLinkLocalUnicast() bool {
type IP4Header struct { type IP4Header struct {
IPProto IPProto IPProto IPProto
IPID uint16 IPID uint16
SrcIP IP4 Src netaddr.IP
DstIP IP4 Dst netaddr.IP
} }
// Len implements Header. // Len implements Header.
@ -63,6 +27,8 @@ func (h IP4Header) Len() int {
return ip4HeaderLength return ip4HeaderLength
} }
var errWrongFamily = errors.New("wrong address family for src/dst IP")
// Marshal implements Header. // Marshal implements Header.
func (h IP4Header) Marshal(buf []byte) error { func (h IP4Header) Marshal(buf []byte) error {
if len(buf) < h.Len() { if len(buf) < h.Len() {
@ -71,6 +37,9 @@ func (h IP4Header) Marshal(buf []byte) error {
if len(buf) > maxPacketLength { if len(buf) > maxPacketLength {
return errLargePacket return errLargePacket
} }
if !h.Src.Is4() || !h.Dst.Is4() {
return errWrongFamily
}
buf[0] = 0x40 | (byte(h.Len() >> 2)) // IPv4 + IHL buf[0] = 0x40 | (byte(h.Len() >> 2)) // IPv4 + IHL
buf[1] = 0x00 // DSCP + ECN buf[1] = 0x00 // DSCP + ECN
@ -83,8 +52,10 @@ func (h IP4Header) Marshal(buf []byte) error {
// it later, because the checksum computation runs over these // it later, because the checksum computation runs over these
// bytes and expects them to be zero. // bytes and expects them to be zero.
binary.BigEndian.PutUint16(buf[10:12], 0) binary.BigEndian.PutUint16(buf[10:12], 0)
binary.BigEndian.PutUint32(buf[12:16], uint32(h.SrcIP)) // Src src := h.Src.As4()
binary.BigEndian.PutUint32(buf[16:20], uint32(h.DstIP)) // Dst dst := h.Dst.As4()
copy(buf[12:16], src[:])
copy(buf[16:20], dst[:])
binary.BigEndian.PutUint16(buf[10:12], ip4Checksum(buf[0:20])) // Checksum binary.BigEndian.PutUint16(buf[10:12], ip4Checksum(buf[0:20])) // Checksum
@ -93,7 +64,7 @@ func (h IP4Header) Marshal(buf []byte) error {
// ToResponse implements Header. // ToResponse implements Header.
func (h *IP4Header) ToResponse() { func (h *IP4Header) ToResponse() {
h.SrcIP, h.DstIP = h.DstIP, h.SrcIP h.Src, h.Dst = h.Dst, h.Src
// Flip the bits in the IPID. If incoming IPIDs are distinct, so are these. // Flip the bits in the IPID. If incoming IPIDs are distinct, so are these.
h.IPID = ^h.IPID h.IPID = ^h.IPID
} }
@ -135,8 +106,9 @@ func (h IP4Header) marshalPseudo(buf []byte) error {
} }
length := len(buf) - h.Len() length := len(buf) - h.Len()
binary.BigEndian.PutUint32(buf[8:12], uint32(h.SrcIP)) src, dst := h.Src.As4(), h.Dst.As4()
binary.BigEndian.PutUint32(buf[12:16], uint32(h.DstIP)) copy(buf[8:12], src[:])
copy(buf[12:16], dst[:])
buf[16] = 0x0 buf[16] = 0x0
buf[17] = uint8(h.IPProto) buf[17] = uint8(h.IPProto)
binary.BigEndian.PutUint16(buf[18:20], uint16(length)) binary.BigEndian.PutUint16(buf[18:20], uint16(length))

View File

@ -6,49 +6,10 @@
import ( import (
"encoding/binary" "encoding/binary"
"fmt"
"inet.af/netaddr" "inet.af/netaddr"
) )
// IP6 is an IPv6 address.
type IP6 struct {
Hi, Lo uint64
}
// IP6FromRaw16 converts a raw 16-byte IPv6 address to an IP6.
func IP6FromRaw16(ip [16]byte) IP6 {
return IP6{binary.BigEndian.Uint64(ip[:8]), binary.BigEndian.Uint64(ip[8:])}
}
// IP6FromNetaddr converts a netaddr.IP to an IP6. Panics if !ip.Is6.
func IP6FromNetaddr(ip netaddr.IP) IP6 {
if !ip.Is6() {
panic(fmt.Sprintf("IP6FromNetaddr called with non-v6 addr %q", ip))
}
return IP6FromRaw16(ip.As16())
}
// Netaddr converts ip to a netaddr.IP.
func (ip IP6) Netaddr() netaddr.IP {
var b [16]byte
binary.BigEndian.PutUint64(b[:8], ip.Hi)
binary.BigEndian.PutUint64(b[8:], ip.Lo)
return netaddr.IPFrom16(b)
}
func (ip IP6) String() string {
return ip.Netaddr().String()
}
func (ip IP6) IsMulticast() bool {
return (ip.Hi >> 56) == 0xFF
}
func (ip IP6) IsLinkLocalUnicast() bool {
return (ip.Hi >> 48) == 0xFE80
}
// ip6HeaderLength is the length of an IPv6 header with no IP options. // ip6HeaderLength is the length of an IPv6 header with no IP options.
const ip6HeaderLength = 40 const ip6HeaderLength = 40
@ -56,8 +17,8 @@ func (ip IP6) IsLinkLocalUnicast() bool {
type IP6Header struct { type IP6Header struct {
IPProto IPProto IPProto IPProto
IPID uint32 // only lower 20 bits used IPID uint32 // only lower 20 bits used
SrcIP IP6 Src netaddr.IP
DstIP IP6 Dst netaddr.IP
} }
// Len implements Header. // Len implements Header.
@ -79,17 +40,16 @@ func (h IP6Header) Marshal(buf []byte) error {
binary.BigEndian.PutUint16(buf[4:6], uint16(len(buf)-ip6HeaderLength)) // Total length binary.BigEndian.PutUint16(buf[4:6], uint16(len(buf)-ip6HeaderLength)) // Total length
buf[6] = uint8(h.IPProto) // Inner protocol buf[6] = uint8(h.IPProto) // Inner protocol
buf[7] = 64 // TTL buf[7] = 64 // TTL
binary.BigEndian.PutUint64(buf[8:16], h.SrcIP.Hi) src, dst := h.Src.As16(), h.Dst.As16()
binary.BigEndian.PutUint64(buf[16:24], h.SrcIP.Lo) copy(buf[8:24], src[:])
binary.BigEndian.PutUint64(buf[24:32], h.DstIP.Hi) copy(buf[24:40], dst[:])
binary.BigEndian.PutUint64(buf[32:40], h.DstIP.Lo)
return nil return nil
} }
// ToResponse implements Header. // ToResponse implements Header.
func (h *IP6Header) ToResponse() { func (h *IP6Header) ToResponse() {
h.SrcIP, h.DstIP = h.DstIP, h.SrcIP h.Src, h.Dst = h.Dst, h.Src
// Flip the bits in the IPID. If incoming IPIDs are distinct, so are these. // Flip the bits in the IPID. If incoming IPIDs are distinct, so are these.
h.IPID = (^h.IPID) & 0x000FFFFF h.IPID = (^h.IPID) & 0x000FFFFF
} }
@ -104,10 +64,9 @@ func (h IP6Header) marshalPseudo(buf []byte) error {
return errLargePacket return errLargePacket
} }
binary.BigEndian.PutUint64(buf[:8], h.SrcIP.Hi) src, dst := h.Src.As16(), h.Dst.As16()
binary.BigEndian.PutUint64(buf[8:16], h.SrcIP.Lo) copy(buf[:16], src[:])
binary.BigEndian.PutUint64(buf[16:24], h.DstIP.Hi) copy(buf[16:32], dst[:])
binary.BigEndian.PutUint64(buf[24:32], h.DstIP.Lo)
binary.BigEndian.PutUint32(buf[32:36], uint32(len(buf)-h.Len())) binary.BigEndian.PutUint32(buf[32:36], uint32(len(buf)-h.Len()))
buf[36] = 0 buf[36] = 0
buf[37] = 0 buf[37] = 0

View File

@ -7,8 +7,10 @@
import ( import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"net"
"strings" "strings"
"inet.af/netaddr"
"tailscale.com/types/strbuilder" "tailscale.com/types/strbuilder"
) )
@ -38,64 +40,50 @@ type Parsed struct {
IPVersion uint8 IPVersion uint8
// IPProto is the IP subprotocol (UDP, TCP, etc.). Valid iff IPVersion != 0. // IPProto is the IP subprotocol (UDP, TCP, etc.). Valid iff IPVersion != 0.
IPProto IPProto IPProto IPProto
// SrcIP4 is the IPv4 source address. Valid iff IPVersion == 4. // SrcIP4 is the source address. Family matches IPVersion. Port is
SrcIP4 IP4 // valid iff IPProto == TCP || IPProto == UDP.
// DstIP4 is the IPv4 destination address. Valid iff IPVersion == 4. Src netaddr.IPPort
DstIP4 IP4 // DstIP4 is the destination address. Family matches IPVersion.
// SrcIP6 is the IPv6 source address. Valid iff IPVersion == 6. Dst netaddr.IPPort
SrcIP6 IP6
// DstIP6 is the IPv6 destination address. Valid iff IPVersion == 6.
DstIP6 IP6
// SrcPort is the TCP/UDP source port. Valid iff IPProto == TCP || IPProto == UDP.
SrcPort uint16
// DstPort is the TCP/UDP source port. Valid iff IPProto == TCP || IPProto == UDP.
DstPort uint16
// TCPFlags is the packet's TCP flag bigs. Valid iff IPProto == TCP. // TCPFlags is the packet's TCP flag bigs. Valid iff IPProto == TCP.
TCPFlags uint8 TCPFlags uint8
} }
func (p *Parsed) String() string { func (p *Parsed) String() string {
switch p.IPVersion { if p.IPVersion != 4 && p.IPVersion != 6 {
case 4:
sb := strbuilder.Get()
sb.WriteString(p.IPProto.String())
sb.WriteByte('{')
writeIP4Port(sb, p.SrcIP4, p.SrcPort)
sb.WriteString(" > ")
writeIP4Port(sb, p.DstIP4, p.DstPort)
sb.WriteByte('}')
return sb.String()
case 6:
sb := strbuilder.Get()
sb.WriteString(p.IPProto.String())
sb.WriteByte('{')
writeIP6Port(sb, p.SrcIP6, p.SrcPort)
sb.WriteString(" > ")
writeIP6Port(sb, p.DstIP6, p.DstPort)
sb.WriteByte('}')
return sb.String()
default:
return "Unknown{???}" return "Unknown{???}"
} }
sb := strbuilder.Get()
sb.WriteString(p.IPProto.String())
sb.WriteByte('{')
writeIPPort(sb, p.Src)
sb.WriteString(" > ")
writeIPPort(sb, p.Dst)
sb.WriteByte('}')
return sb.String()
} }
func writeIP4Port(sb *strbuilder.Builder, ip IP4, port uint16) { // writeIPPort writes ipp.String() into sb, with fewer allocations.
sb.WriteUint(uint64(byte(ip >> 24))) //
sb.WriteByte('.') // TODO: make netaddr more efficient in this area, and retire this func.
sb.WriteUint(uint64(byte(ip >> 16))) func writeIPPort(sb *strbuilder.Builder, ipp netaddr.IPPort) {
sb.WriteByte('.') if ipp.IP.Is4() {
sb.WriteUint(uint64(byte(ip >> 8))) raw := ipp.IP.As4()
sb.WriteByte('.') sb.WriteUint(uint64(raw[0]))
sb.WriteUint(uint64(byte(ip))) sb.WriteByte('.')
sb.WriteByte(':') sb.WriteUint(uint64(raw[1]))
sb.WriteUint(uint64(port)) sb.WriteByte('.')
} sb.WriteUint(uint64(raw[2]))
sb.WriteByte('.')
func writeIP6Port(sb *strbuilder.Builder, ip IP6, port uint16) { sb.WriteUint(uint64(raw[3]))
sb.WriteByte('[') sb.WriteByte(':')
sb.WriteString(ip.Netaddr().String()) // TODO: faster? } else {
sb.WriteString("]:") sb.WriteByte('[')
sb.WriteUint(uint64(port)) sb.WriteString(ipp.IP.String()) // TODO: faster?
sb.WriteString("]:")
}
sb.WriteUint(uint64(ipp.Port))
} }
// Decode extracts data from the packet in b into q. // Decode extracts data from the packet in b into q.
@ -140,8 +128,8 @@ func (q *Parsed) decode4(b []byte) {
} }
// If it's valid IPv4, then the IP addresses are valid // If it's valid IPv4, then the IP addresses are valid
q.SrcIP4 = IP4(binary.BigEndian.Uint32(b[12:16])) q.Src.IP = netaddr.IPv4(b[12], b[13], b[14], b[15])
q.DstIP4 = IP4(binary.BigEndian.Uint32(b[16:20])) q.Dst.IP = netaddr.IPv4(b[16], b[17], b[18], b[19])
q.subofs = int((b[0] & 0x0F) << 2) q.subofs = int((b[0] & 0x0F) << 2)
if q.subofs > q.length { if q.subofs > q.length {
@ -183,8 +171,8 @@ func (q *Parsed) decode4(b []byte) {
q.IPProto = Unknown q.IPProto = Unknown
return return
} }
q.SrcPort = 0 q.Src.Port = 0
q.DstPort = 0 q.Dst.Port = 0
q.dataofs = q.subofs + icmp4HeaderLength q.dataofs = q.subofs + icmp4HeaderLength
return return
case IGMP: case IGMP:
@ -196,8 +184,8 @@ func (q *Parsed) decode4(b []byte) {
q.IPProto = Unknown q.IPProto = Unknown
return return
} }
q.SrcPort = binary.BigEndian.Uint16(sub[0:2]) q.Src.Port = binary.BigEndian.Uint16(sub[0:2])
q.DstPort = binary.BigEndian.Uint16(sub[2:4]) q.Dst.Port = binary.BigEndian.Uint16(sub[2:4])
q.TCPFlags = sub[13] & 0x3F q.TCPFlags = sub[13] & 0x3F
headerLength := (sub[12] & 0xF0) >> 2 headerLength := (sub[12] & 0xF0) >> 2
q.dataofs = q.subofs + int(headerLength) q.dataofs = q.subofs + int(headerLength)
@ -207,8 +195,8 @@ func (q *Parsed) decode4(b []byte) {
q.IPProto = Unknown q.IPProto = Unknown
return return
} }
q.SrcPort = binary.BigEndian.Uint16(sub[0:2]) q.Src.Port = binary.BigEndian.Uint16(sub[0:2])
q.DstPort = binary.BigEndian.Uint16(sub[2:4]) q.Dst.Port = binary.BigEndian.Uint16(sub[2:4])
q.dataofs = q.subofs + udpHeaderLength q.dataofs = q.subofs + udpHeaderLength
return return
default: default:
@ -249,10 +237,10 @@ func (q *Parsed) decode6(b []byte) {
return return
} }
q.SrcIP6.Hi = binary.BigEndian.Uint64(b[8:16]) // okay to ignore `ok` here, because IPs pulled from packets are
q.SrcIP6.Lo = binary.BigEndian.Uint64(b[16:24]) // always well-formed stdlib IPs.
q.DstIP6.Hi = binary.BigEndian.Uint64(b[24:32]) q.Src.IP, _ = netaddr.FromStdIP(net.IP(b[8:24]))
q.DstIP6.Lo = binary.BigEndian.Uint64(b[32:40]) q.Dst.IP, _ = netaddr.FromStdIP(net.IP(b[24:40]))
// We don't support any IPv6 extension headers. Don't try to // We don't support any IPv6 extension headers. Don't try to
// be clever. Therefore, the IP subprotocol always starts at // be clever. Therefore, the IP subprotocol always starts at
@ -276,16 +264,16 @@ func (q *Parsed) decode6(b []byte) {
q.IPProto = Unknown q.IPProto = Unknown
return return
} }
q.SrcPort = 0 q.Src.Port = 0
q.DstPort = 0 q.Dst.Port = 0
q.dataofs = q.subofs + icmp6HeaderLength q.dataofs = q.subofs + icmp6HeaderLength
case TCP: case TCP:
if len(sub) < tcpHeaderLength { if len(sub) < tcpHeaderLength {
q.IPProto = Unknown q.IPProto = Unknown
return return
} }
q.SrcPort = binary.BigEndian.Uint16(sub[0:2]) q.Src.Port = binary.BigEndian.Uint16(sub[0:2])
q.DstPort = binary.BigEndian.Uint16(sub[2:4]) q.Dst.Port = binary.BigEndian.Uint16(sub[2:4])
q.TCPFlags = sub[13] & 0x3F q.TCPFlags = sub[13] & 0x3F
headerLength := (sub[12] & 0xF0) >> 2 headerLength := (sub[12] & 0xF0) >> 2
q.dataofs = q.subofs + int(headerLength) q.dataofs = q.subofs + int(headerLength)
@ -295,8 +283,8 @@ func (q *Parsed) decode6(b []byte) {
q.IPProto = Unknown q.IPProto = Unknown
return return
} }
q.SrcPort = binary.BigEndian.Uint16(sub[0:2]) q.Src.Port = binary.BigEndian.Uint16(sub[0:2])
q.DstPort = binary.BigEndian.Uint16(sub[2:4]) q.Dst.Port = binary.BigEndian.Uint16(sub[2:4])
q.dataofs = q.subofs + udpHeaderLength q.dataofs = q.subofs + udpHeaderLength
default: default:
q.IPProto = Unknown q.IPProto = Unknown
@ -312,8 +300,8 @@ func (q *Parsed) IP4Header() IP4Header {
return IP4Header{ return IP4Header{
IPID: ipid, IPID: ipid,
IPProto: q.IPProto, IPProto: q.IPProto,
SrcIP: q.SrcIP4, Src: q.Src.IP,
DstIP: q.DstIP4, Dst: q.Dst.IP,
} }
} }
@ -334,8 +322,8 @@ func (q *Parsed) UDP4Header() UDP4Header {
} }
return UDP4Header{ return UDP4Header{
IP4Header: q.IP4Header(), IP4Header: q.IP4Header(),
SrcPort: q.SrcPort, SrcPort: q.Src.Port,
DstPort: q.DstPort, DstPort: q.Dst.Port,
} }
} }

View File

@ -12,54 +12,12 @@
"inet.af/netaddr" "inet.af/netaddr"
) )
func mustIP4(s string) IP4 { func mustIPPort(s string) netaddr.IPPort {
ip, err := netaddr.ParseIP(s) ipp, err := netaddr.ParseIPPort(s)
if err != nil { if err != nil {
panic(err) panic(err)
} }
return IP4FromNetaddr(ip) return ipp
}
func mustIP6(s string) IP6 {
ip, err := netaddr.ParseIP(s)
if err != nil {
panic(err)
}
return IP6FromNetaddr(ip)
}
func TestIP4String(t *testing.T) {
const str = "1.2.3.4"
ip := mustIP4(str)
var got string
allocs := testing.AllocsPerRun(1000, func() {
got = ip.String()
})
if got != str {
t.Errorf("got %q; want %q", got, str)
}
if allocs != 1 {
t.Errorf("allocs = %v; want 1", allocs)
}
}
func TestIP6String(t *testing.T) {
const str = "2607:f8b0:400a:809::200e"
ip := mustIP6(str)
var got string
allocs := testing.AllocsPerRun(1000, func() {
got = ip.String()
})
if got != str {
t.Errorf("got %q; want %q", got, str)
}
if allocs != 1 {
t.Errorf("allocs = %v; want 1", allocs)
}
} }
var icmp4RequestBuffer = []byte{ var icmp4RequestBuffer = []byte{
@ -83,10 +41,8 @@ func TestIP6String(t *testing.T) {
IPVersion: 4, IPVersion: 4,
IPProto: ICMPv4, IPProto: ICMPv4,
SrcIP4: mustIP4("1.2.3.4"), Src: mustIPPort("1.2.3.4:0"),
DstIP4: mustIP4("5.6.7.8"), Dst: mustIPPort("5.6.7.8:0"),
SrcPort: 0,
DstPort: 0,
} }
var icmp4ReplyBuffer = []byte{ var icmp4ReplyBuffer = []byte{
@ -109,10 +65,8 @@ func TestIP6String(t *testing.T) {
IPVersion: 4, IPVersion: 4,
IPProto: ICMPv4, IPProto: ICMPv4,
SrcIP4: mustIP4("1.2.3.4"), Src: mustIPPort("1.2.3.4:0"),
DstIP4: mustIP4("5.6.7.8"), Dst: mustIPPort("5.6.7.8:0"),
SrcPort: 0,
DstPort: 0,
} }
// ICMPv6 Router Solicitation // ICMPv6 Router Solicitation
@ -132,8 +86,8 @@ func TestIP6String(t *testing.T) {
length: len(icmp6PacketBuffer), length: len(icmp6PacketBuffer),
IPVersion: 6, IPVersion: 6,
IPProto: ICMPv6, IPProto: ICMPv6,
SrcIP6: mustIP6("fe80::fb57:1dea:9c39:8fb7"), Src: mustIPPort("[fe80::fb57:1dea:9c39:8fb7]:0"),
DstIP6: mustIP6("ff02::2"), Dst: mustIPPort("[ff02::2]:0"),
} }
// This is a malformed IPv4 packet. // This is a malformed IPv4 packet.
@ -170,10 +124,8 @@ func TestIP6String(t *testing.T) {
IPVersion: 4, IPVersion: 4,
IPProto: TCP, IPProto: TCP,
SrcIP4: mustIP4("1.2.3.4"), Src: mustIPPort("1.2.3.4:123"),
DstIP4: mustIP4("5.6.7.8"), Dst: mustIPPort("5.6.7.8:567"),
SrcPort: 123,
DstPort: 567,
TCPFlags: TCPSynAck, TCPFlags: TCPSynAck,
} }
@ -198,10 +150,8 @@ func TestIP6String(t *testing.T) {
IPVersion: 6, IPVersion: 6,
IPProto: TCP, IPProto: TCP,
SrcIP6: mustIP6("2001:559:bc13:5400:1749:4628:3934:e1b"), Src: mustIPPort("[2001:559:bc13:5400:1749:4628:3934:e1b]:42080"),
DstIP6: mustIP6("2607:f8b0:400a:809::200e"), Dst: mustIPPort("[2607:f8b0:400a:809::200e]:80"),
SrcPort: 42080,
DstPort: 80,
TCPFlags: TCPSyn, TCPFlags: TCPSyn,
} }
@ -226,10 +176,8 @@ func TestIP6String(t *testing.T) {
IPVersion: 4, IPVersion: 4,
IPProto: UDP, IPProto: UDP,
SrcIP4: mustIP4("1.2.3.4"), Src: mustIPPort("1.2.3.4:123"),
DstIP4: mustIP4("5.6.7.8"), Dst: mustIPPort("5.6.7.8:567"),
SrcPort: 123,
DstPort: 567,
} }
var invalid4RequestBuffer = []byte{ var invalid4RequestBuffer = []byte{
@ -250,8 +198,8 @@ func TestIP6String(t *testing.T) {
IPVersion: 4, IPVersion: 4,
IPProto: Unknown, IPProto: Unknown,
SrcIP4: mustIP4("1.2.3.4"), Src: mustIPPort("1.2.3.4:0"),
DstIP4: mustIP4("5.6.7.8"), Dst: mustIPPort("5.6.7.8:0"),
} }
var udp6RequestBuffer = []byte{ var udp6RequestBuffer = []byte{
@ -275,10 +223,8 @@ func TestIP6String(t *testing.T) {
IPVersion: 6, IPVersion: 6,
IPProto: UDP, IPProto: UDP,
SrcIP6: mustIP6("2001:559:bc13:5400:1749:4628:3934:e1b"), Src: mustIPPort("[2001:559:bc13:5400:1749:4628:3934:e1b]:54276"),
DstIP6: mustIP6("2607:f8b0:400a:809::200e"), Dst: mustIPPort("[2607:f8b0:400a:809::200e]:443"),
SrcPort: 54276,
DstPort: 443,
} }
var udp4ReplyBuffer = []byte{ var udp4ReplyBuffer = []byte{
@ -301,10 +247,8 @@ func TestIP6String(t *testing.T) {
length: len(udp4ReplyBuffer), length: len(udp4ReplyBuffer),
IPProto: UDP, IPProto: UDP,
SrcIP4: mustIP4("1.2.3.4"), Src: mustIPPort("1.2.3.4:567"),
DstIP4: mustIP4("5.6.7.8"), Dst: mustIPPort("5.6.7.8:123"),
SrcPort: 567,
DstPort: 123,
} }
var igmpPacketBuffer = []byte{ var igmpPacketBuffer = []byte{
@ -326,8 +270,8 @@ func TestIP6String(t *testing.T) {
IPVersion: 4, IPVersion: 4,
IPProto: IGMP, IPProto: IGMP,
SrcIP4: mustIP4("192.168.1.82"), Src: mustIPPort("192.168.1.82:0"),
DstIP4: mustIP4("224.0.0.251"), Dst: mustIPPort("224.0.0.251:0"),
} }
func TestParsed(t *testing.T) { func TestParsed(t *testing.T) {

View File

@ -25,45 +25,33 @@ type Filter struct {
// tailscale must have a destination within local4 or local6, // tailscale must have a destination within local4 or local6,
// regardless of the policy filter below. Zero values reject // regardless of the policy filter below. Zero values reject
// all incoming traffic. // all incoming traffic.
local4 []net4 local4 []netaddr.IPPrefix
local6 []net6 local6 []netaddr.IPPrefix
// matches4 and matches6 are lists of match->action rules // matches4 and matches6 are lists of match->action rules
// applied to all packets arriving over tailscale // applied to all packets arriving over tailscale
// tunnels. Matches are checked in order, and processing stops // tunnels. Matches are checked in order, and processing stops
// at the first matching rule. The default policy if no rules // at the first matching rule. The default policy if no rules
// match is to drop the packet. // match is to drop the packet.
matches4 matches4 matches4 matches
matches6 matches6 matches6 matches
// state is the connection tracking state attached to this // state is the connection tracking state attached to this
// filter. It is used to allow incoming traffic that is a response // filter. It is used to allow incoming traffic that is a response
// to an outbound connection that this node made, even if those // to an outbound connection that this node made, even if those
// incoming packets don't get accepted by matches above. // incoming packets don't get accepted by matches above.
state4 *filterState state *filterState
state6 *filterState
} }
// tuple4 is a 4-tuple of source and destination IPv4 and port. It's // tuple is a 4-tuple of source and destination IP and port. It's used
// used as a lookup key in filterState. // as a lookup key in filterState.
type tuple4 struct { type tuple struct {
SrcIP packet.IP4 Src netaddr.IPPort
DstIP packet.IP4 Dst netaddr.IPPort
SrcPort uint16
DstPort uint16
}
// tuple6 is a 4-tuple of source and destination IPv6 and port. It's
// used as a lookup key in filterState.
type tuple6 struct {
SrcIP packet.IP6
DstIP packet.IP6
SrcPort uint16
DstPort uint16
} }
// filterState is a state cache of past seen packets. // filterState is a state cache of past seen packets.
type filterState struct { type filterState struct {
mu sync.Mutex mu sync.Mutex
lru *lru.Cache // of tuple4 or tuple6 lru *lru.Cache // of tuple
} }
// lruMax is the size of the LRU cache in filterState. // lruMax is the size of the LRU cache in filterState.
@ -148,30 +136,58 @@ func NewAllowNone(logf logger.Logf) *Filter {
// shares state with the previous one, to enable changing rules at // shares state with the previous one, to enable changing rules at
// runtime without breaking existing stateful flows. // runtime without breaking existing stateful flows.
func New(matches []Match, localNets []netaddr.IPPrefix, shareStateWith *Filter, logf logger.Logf) *Filter { func New(matches []Match, localNets []netaddr.IPPrefix, shareStateWith *Filter, logf logger.Logf) *Filter {
var state4, state6 *filterState var state *filterState
if shareStateWith != nil { if shareStateWith != nil {
state4 = shareStateWith.state4 state = shareStateWith.state
state6 = shareStateWith.state6
} else { } else {
state4 = &filterState{ state = &filterState{
lru: lru.New(lruMax),
}
state6 = &filterState{
lru: lru.New(lruMax), lru: lru.New(lruMax),
} }
} }
f := &Filter{ f := &Filter{
logf: logf, logf: logf,
matches4: newMatches4(matches), matches4: matchesFamily(matches, netaddr.IP.Is4),
matches6: newMatches6(matches), matches6: matchesFamily(matches, netaddr.IP.Is6),
local4: nets4FromIPPrefixes(localNets), local4: netsFamily(localNets, netaddr.IP.Is4),
local6: nets6FromIPPrefixes(localNets), local6: netsFamily(localNets, netaddr.IP.Is6),
state4: state4, state: state,
state6: state6,
} }
return f return f
} }
func netsFamily(nets []netaddr.IPPrefix, keep func(netaddr.IP) bool) []netaddr.IPPrefix {
var ret []netaddr.IPPrefix
for _, net := range nets {
if keep(net.IP) {
ret = append(ret, net)
}
}
return ret
}
// matchesFamily returns the subset of ms for which keep(srcNet.IP)
// and keep(dstNet.IP) are both true.
func matchesFamily(ms matches, keep func(netaddr.IP) bool) matches {
var ret matches
for _, m := range ms {
var retm Match
for _, src := range m.Srcs {
if keep(src.IP) {
retm.Srcs = append(retm.Srcs, src)
}
}
for _, dst := range m.Dsts {
if keep(dst.Net.IP) {
retm.Dsts = append(retm.Dsts, dst)
}
}
if len(retm.Srcs) > 0 && len(retm.Dsts) > 0 {
ret = append(ret, retm)
}
}
return ret
}
func maybeHexdump(flag RunFlags, b []byte) string { func maybeHexdump(flag RunFlags, b []byte) string {
if flag == 0 { if flag == 0 {
return "" return ""
@ -229,19 +245,17 @@ func (f *Filter) CheckTCP(srcIP, dstIP netaddr.IP, dstPort uint16) Response {
return Drop return Drop
case srcIP.Is4(): case srcIP.Is4():
pkt.IPVersion = 4 pkt.IPVersion = 4
pkt.SrcIP4 = packet.IP4FromNetaddr(srcIP)
pkt.DstIP4 = packet.IP4FromNetaddr(dstIP)
case srcIP.Is6(): case srcIP.Is6():
pkt.IPVersion = 6 pkt.IPVersion = 6
pkt.SrcIP6 = packet.IP6FromNetaddr(srcIP)
pkt.DstIP6 = packet.IP6FromNetaddr(dstIP)
default: default:
panic("unreachable") panic("unreachable")
} }
pkt.Src.IP = srcIP
pkt.Dst.IP = dstIP
pkt.IPProto = packet.TCP pkt.IPProto = packet.TCP
pkt.TCPFlags = packet.TCPSyn pkt.TCPFlags = packet.TCPSyn
pkt.SrcPort = 0 pkt.Src.Port = 0
pkt.DstPort = dstPort pkt.Dst.Port = dstPort
return f.RunIn(pkt, 0) return f.RunIn(pkt, 0)
} }
@ -287,7 +301,7 @@ func (f *Filter) runIn4(q *packet.Parsed) (r Response, why string) {
// A compromised peer could try to send us packets for // A compromised peer could try to send us packets for
// destinations we didn't explicitly advertise. This check is to // destinations we didn't explicitly advertise. This check is to
// prevent that. // prevent that.
if !ip4InList(q.DstIP4, f.local4) { if !ipInList(q.Dst.IP, f.local4) {
return Drop, "destination not allowed" return Drop, "destination not allowed"
} }
@ -320,11 +334,11 @@ func (f *Filter) runIn4(q *packet.Parsed) (r Response, why string) {
return Accept, "tcp ok" return Accept, "tcp ok"
} }
case packet.UDP: case packet.UDP:
t := tuple4{q.SrcIP4, q.DstIP4, q.SrcPort, q.DstPort} t := tuple{q.Src, q.Dst}
f.state4.mu.Lock() f.state.mu.Lock()
_, ok := f.state4.lru.Get(t) _, ok := f.state.lru.Get(t)
f.state4.mu.Unlock() f.state.mu.Unlock()
if ok { if ok {
return Accept, "udp cached" return Accept, "udp cached"
@ -342,7 +356,7 @@ func (f *Filter) runIn6(q *packet.Parsed) (r Response, why string) {
// A compromised peer could try to send us packets for // A compromised peer could try to send us packets for
// destinations we didn't explicitly advertise. This check is to // destinations we didn't explicitly advertise. This check is to
// prevent that. // prevent that.
if !ip6InList(q.DstIP6, f.local6) { if !ipInList(q.Dst.IP, f.local6) {
return Drop, "destination not allowed" return Drop, "destination not allowed"
} }
@ -375,11 +389,11 @@ func (f *Filter) runIn6(q *packet.Parsed) (r Response, why string) {
return Accept, "tcp ok" return Accept, "tcp ok"
} }
case packet.UDP: case packet.UDP:
t := tuple6{q.SrcIP6, q.DstIP6, q.SrcPort, q.DstPort} t := tuple{q.Src, q.Dst}
f.state6.mu.Lock() f.state.mu.Lock()
_, ok := f.state6.lru.Get(t) _, ok := f.state.lru.Get(t)
f.state6.mu.Unlock() f.state.mu.Unlock()
if ok { if ok {
return Accept, "udp cached" return Accept, "udp cached"
@ -399,20 +413,11 @@ func (f *Filter) runOut(q *packet.Parsed) (r Response, why string) {
return Accept, "ok out" return Accept, "ok out"
} }
switch q.IPVersion { t := tuple{q.Dst, q.Src}
case 4: var ti interface{} = t // allocate once, rather than twice inside mutex
t := tuple4{q.DstIP4, q.SrcIP4, q.DstPort, q.SrcPort} f.state.mu.Lock()
var ti interface{} = t // allocate once, rather than twice inside mutex f.state.lru.Add(ti, ti)
f.state4.mu.Lock() f.state.mu.Unlock()
f.state4.lru.Add(ti, ti)
f.state4.mu.Unlock()
case 6:
t := tuple6{q.DstIP6, q.SrcIP6, q.DstPort, q.SrcPort}
var ti interface{} = t // allocate once, rather than twice inside mutex
f.state6.mu.Lock()
f.state6.lru.Add(ti, ti)
f.state6.mu.Unlock()
}
return Accept, "ok out" return Accept, "ok out"
} }
@ -436,6 +441,8 @@ func (d direction) String() string {
} }
} }
var gcpDNSAddr = netaddr.IPv4(169, 254, 169, 254)
// pre runs the direction-agnostic filter logic. dir is only used for // pre runs the direction-agnostic filter logic. dir is only used for
// logging. // logging.
func (f *Filter) pre(q *packet.Parsed, rf RunFlags, dir direction) Response { func (f *Filter) pre(q *packet.Parsed, rf RunFlags, dir direction) Response {
@ -448,25 +455,13 @@ func (f *Filter) pre(q *packet.Parsed, rf RunFlags, dir direction) Response {
return Drop return Drop
} }
switch q.IPVersion { if q.Dst.IP.IsMulticast() {
case 4: f.logRateLimit(rf, q, dir, Drop, "multicast")
if q.DstIP4.IsMulticast() { return Drop
f.logRateLimit(rf, q, dir, Drop, "multicast") }
return Drop if q.Dst.IP.IsLinkLocalUnicast() && q.Dst.IP != gcpDNSAddr {
} f.logRateLimit(rf, q, dir, Drop, "link-local-unicast")
if q.DstIP4.IsMostLinkLocalUnicast() { return Drop
f.logRateLimit(rf, q, dir, Drop, "link-local-unicast")
return Drop
}
case 6:
if q.DstIP6.IsMulticast() {
f.logRateLimit(rf, q, dir, Drop, "multicast")
return Drop
}
if q.DstIP6.IsLinkLocalUnicast() {
f.logRateLimit(rf, q, dir, Drop, "link-local-unicast")
return Drop
}
} }
switch q.IPProto { switch q.IPProto {
@ -493,12 +488,5 @@ func omitDropLogging(p *packet.Parsed, dir direction) bool {
return false return false
} }
switch p.IPVersion { return p.Dst.IP.IsMulticast() || (p.Dst.IP.IsLinkLocalUnicast() && p.Dst.IP != gcpDNSAddr) || p.IPProto == packet.IGMP
case 4:
return p.DstIP4.IsMulticast() || p.DstIP4.IsMostLinkLocalUnicast() || p.IPProto == packet.IGMP
case 6:
return p.DstIP6.IsMulticast() || p.DstIP6.IsLinkLocalUnicast()
default:
return false
}
} }

View File

@ -94,9 +94,9 @@ type InOut struct {
if test.p.IPProto == packet.TCP { if test.p.IPProto == packet.TCP {
var got Response var got Response
if test.p.IPVersion == 4 { if test.p.IPVersion == 4 {
got = acl.CheckTCP(test.p.SrcIP4.Netaddr(), test.p.DstIP4.Netaddr(), test.p.DstPort) got = acl.CheckTCP(test.p.Src.IP, test.p.Dst.IP, test.p.Dst.Port)
} else { } else {
got = acl.CheckTCP(test.p.SrcIP6.Netaddr(), test.p.DstIP6.Netaddr(), test.p.DstPort) got = acl.CheckTCP(test.p.Src.IP, test.p.Dst.IP, test.p.Dst.Port)
} }
if test.want != got { if test.want != got {
t.Errorf("#%d CheckTCP got=%v want=%v packet:%v", i, got, test.want, test.p) t.Errorf("#%d CheckTCP got=%v want=%v packet:%v", i, got, test.want, test.p)
@ -345,19 +345,19 @@ func TestOmitDropLogging(t *testing.T) {
}, },
{ {
name: "v4_multicast_out_low", name: "v4_multicast_out_low",
pkt: &packet.Parsed{IPVersion: 4, DstIP4: mustIP4("224.0.0.0")}, pkt: &packet.Parsed{IPVersion: 4, Dst: mustIPPort("224.0.0.0:0")},
dir: out, dir: out,
want: true, want: true,
}, },
{ {
name: "v4_multicast_out_high", name: "v4_multicast_out_high",
pkt: &packet.Parsed{IPVersion: 4, DstIP4: mustIP4("239.255.255.255")}, pkt: &packet.Parsed{IPVersion: 4, Dst: mustIPPort("239.255.255.255:0")},
dir: out, dir: out,
want: true, want: true,
}, },
{ {
name: "v4_link_local_unicast", name: "v4_link_local_unicast",
pkt: &packet.Parsed{IPVersion: 4, DstIP4: mustIP4("169.254.1.2")}, pkt: &packet.Parsed{IPVersion: 4, Dst: mustIPPort("169.254.1.2:0")},
dir: out, dir: out,
want: true, want: true,
}, },
@ -387,18 +387,16 @@ func parsed(proto packet.IPProto, src, dst string, sport, dport uint16) packet.P
var ret packet.Parsed var ret packet.Parsed
ret.Decode(dummyPacket) ret.Decode(dummyPacket)
ret.IPProto = proto ret.IPProto = proto
ret.SrcPort = sport ret.Src.IP = sip
ret.DstPort = dport ret.Src.Port = sport
ret.Dst.IP = dip
ret.Dst.Port = dport
ret.TCPFlags = packet.TCPSyn ret.TCPFlags = packet.TCPSyn
if sip.Is4() { if sip.Is4() {
ret.IPVersion = 4 ret.IPVersion = 4
ret.SrcIP4 = packet.IP4FromNetaddr(sip)
ret.DstIP4 = packet.IP4FromNetaddr(dip)
} else { } else {
ret.IPVersion = 6 ret.IPVersion = 6
ret.SrcIP6 = packet.IP6FromNetaddr(sip)
ret.DstIP6 = packet.IP6FromNetaddr(dip)
} }
return ret return ret
@ -407,8 +405,8 @@ func parsed(proto packet.IPProto, src, dst string, sport, dport uint16) packet.P
func raw6(proto packet.IPProto, src, dst string, sport, dport uint16, trimLen int) []byte { func raw6(proto packet.IPProto, src, dst string, sport, dport uint16, trimLen int) []byte {
u := packet.UDP6Header{ u := packet.UDP6Header{
IP6Header: packet.IP6Header{ IP6Header: packet.IP6Header{
SrcIP: packet.IP6FromNetaddr(mustIP(src)), Src: mustIP(src),
DstIP: packet.IP6FromNetaddr(mustIP(dst)), Dst: mustIP(dst),
}, },
SrcPort: sport, SrcPort: sport,
DstPort: dport, DstPort: dport,
@ -436,8 +434,8 @@ func raw6(proto packet.IPProto, src, dst string, sport, dport uint16, trimLen in
func raw4(proto packet.IPProto, src, dst string, sport, dport uint16, trimLength int) []byte { func raw4(proto packet.IPProto, src, dst string, sport, dport uint16, trimLength int) []byte {
u := packet.UDP4Header{ u := packet.UDP4Header{
IP4Header: packet.IP4Header{ IP4Header: packet.IP4Header{
SrcIP: packet.IP4FromNetaddr(mustIP(src)), Src: mustIP(src),
DstIP: packet.IP4FromNetaddr(mustIP(dst)), Dst: mustIP(dst),
}, },
SrcPort: sport, SrcPort: sport,
DstPort: dport, DstPort: dport,
@ -488,12 +486,12 @@ func parseHexPkt(t *testing.T, h string) *packet.Parsed {
return p return p
} }
func mustIP4(s string) packet.IP4 { func mustIPPort(s string) netaddr.IPPort {
ip, err := netaddr.ParseIP(s) ipp, err := netaddr.ParseIPPort(s)
if err != nil { if err != nil {
panic(err) panic(err)
} }
return packet.IP4FromNetaddr(ip) return ipp
} }
func pfx(strs ...string) (ret []netaddr.IPPrefix) { func pfx(strs ...string) (ret []netaddr.IPPrefix) {

View File

@ -9,6 +9,7 @@
"strings" "strings"
"inet.af/netaddr" "inet.af/netaddr"
"tailscale.com/net/packet"
) )
// PortRange is a range of TCP and UDP ports. // PortRange is a range of TCP and UDP ports.
@ -71,3 +72,46 @@ func (m Match) String() string {
} }
return fmt.Sprintf("%v=>%v", ss, ds) return fmt.Sprintf("%v=>%v", ss, ds)
} }
type matches []Match
func (ms matches) match(q *packet.Parsed) bool {
for _, m := range ms {
if !ipInList(q.Src.IP, m.Srcs) {
continue
}
for _, dst := range m.Dsts {
if !dst.Net.Contains(q.Dst.IP) {
continue
}
if !dst.Ports.contains(q.Dst.Port) {
continue
}
return true
}
}
return false
}
func (ms matches) matchIPsOnly(q *packet.Parsed) bool {
for _, m := range ms {
if !ipInList(q.Src.IP, m.Srcs) {
continue
}
for _, dst := range m.Dsts {
if dst.Net.Contains(q.Dst.IP) {
return true
}
}
}
return false
}
func ipInList(ip netaddr.IP, netlist []netaddr.IPPrefix) bool {
for _, net := range netlist {
if net.Contains(ip) {
return true
}
}
return false
}

View File

@ -1,151 +0,0 @@
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package filter
import (
"fmt"
"math/bits"
"strings"
"inet.af/netaddr"
"tailscale.com/net/packet"
)
type net4 struct {
ip packet.IP4
mask packet.IP4
}
func net4FromIPPrefix(pfx netaddr.IPPrefix) net4 {
if !pfx.IP.Is4() {
panic("net4FromIPPrefix given non-ipv4 prefix")
}
return net4{
ip: packet.IP4FromNetaddr(pfx.IP),
mask: netmask4(pfx.Bits),
}
}
func nets4FromIPPrefixes(pfxs []netaddr.IPPrefix) (ret []net4) {
for _, pfx := range pfxs {
if pfx.IP.Is4() {
ret = append(ret, net4FromIPPrefix(pfx))
}
}
return ret
}
func (n net4) Contains(ip packet.IP4) bool {
return (n.ip & n.mask) == (ip & n.mask)
}
func (n net4) Bits() int {
return 32 - bits.TrailingZeros32(uint32(n.mask))
}
func (n net4) String() string {
b := n.Bits()
if b == 32 {
return n.ip.String()
} else if b == 0 {
return "*"
} else {
return fmt.Sprintf("%s/%d", n.ip, b)
}
}
type npr4 struct {
net net4
ports PortRange
}
func (npr npr4) String() string {
return fmt.Sprintf("%s:%s", npr.net, npr.ports)
}
type match4 struct {
srcs []net4
dsts []npr4
}
type matches4 []match4
func (ms matches4) String() string {
var b strings.Builder
for _, m := range ms {
fmt.Fprintf(&b, "%s => %s\n", m.srcs, m.dsts)
}
return b.String()
}
func newMatches4(ms []Match) (ret matches4) {
for _, m := range ms {
var m4 match4
for _, src := range m.Srcs {
if src.IP.Is4() {
m4.srcs = append(m4.srcs, net4FromIPPrefix(src))
}
}
for _, dst := range m.Dsts {
if dst.Net.IP.Is4() {
m4.dsts = append(m4.dsts, npr4{net4FromIPPrefix(dst.Net), dst.Ports})
}
}
if len(m4.srcs) > 0 && len(m4.dsts) > 0 {
ret = append(ret, m4)
}
}
return ret
}
// match returns whether q's source IP and destination IP:port match
// any of ms.
func (ms matches4) match(q *packet.Parsed) bool {
for _, m := range ms {
if !ip4InList(q.SrcIP4, m.srcs) {
continue
}
for _, dst := range m.dsts {
if !dst.net.Contains(q.DstIP4) {
continue
}
if !dst.ports.contains(q.DstPort) {
continue
}
return true
}
}
return false
}
// matchIPsOnly returns whether q's source and destination IP match
// any of ms.
func (ms matches4) matchIPsOnly(q *packet.Parsed) bool {
for _, m := range ms {
if !ip4InList(q.SrcIP4, m.srcs) {
continue
}
for _, dst := range m.dsts {
if dst.net.Contains(q.DstIP4) {
return true
}
}
}
return false
}
func netmask4(bits uint8) packet.IP4 {
b := ^uint32((1 << (32 - bits)) - 1)
return packet.IP4(b)
}
func ip4InList(ip packet.IP4, netlist []net4) bool {
for _, net := range netlist {
if net.Contains(ip) {
return true
}
}
return false
}

View File

@ -1,171 +0,0 @@
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package filter
import (
"fmt"
"math/bits"
"strings"
"inet.af/netaddr"
"tailscale.com/net/packet"
)
type net6 struct {
ip packet.IP6
mask packet.IP6
}
func net6FromIPPrefix(pfx netaddr.IPPrefix) net6 {
if !pfx.IP.Is6() {
panic("net6FromIPPrefix given non-ipv6 prefix")
}
var mask packet.IP6
if pfx.Bits > 64 {
mask.Hi = ^uint64(0)
mask.Lo = (^uint64(0) << (128 - pfx.Bits))
} else {
mask.Hi = (^uint64(0) << (64 - pfx.Bits))
}
return net6{
ip: packet.IP6FromNetaddr(pfx.IP),
mask: mask,
}
}
func nets6FromIPPrefixes(pfxs []netaddr.IPPrefix) (ret []net6) {
for _, pfx := range pfxs {
if pfx.IP.Is6() {
ret = append(ret, net6FromIPPrefix(pfx))
}
}
return ret
}
func (n net6) Contains(ip packet.IP6) bool {
// This is equivalent to the more straightforward implementation:
// ((n.ip.Hi & n.mask.Hi) == (ip.Hi & n.mask.Hi) &&
// (n.ip.Lo & n.mask.Lo) == (ip.Lo & n.mask.Lo))
//
// This implementation runs significantly faster because it
// eliminates branches and minimizes the required
// bit-twiddling.
a := (n.ip.Hi ^ ip.Hi) & n.mask.Hi
b := (n.ip.Lo ^ ip.Lo) & n.mask.Lo
return (a | b) == 0
}
func (n net6) Bits() int {
return 128 - bits.TrailingZeros64(n.mask.Hi) - bits.TrailingZeros64(n.mask.Lo)
}
func (n net6) String() string {
switch n.Bits() {
case 128:
return n.ip.String()
case 0:
return "*"
default:
return fmt.Sprintf("%s/%d", n.ip, n.Bits())
}
}
type npr6 struct {
net net6
ports PortRange
}
func (npr npr6) String() string {
return fmt.Sprintf("%s:%s", npr.net, npr.ports)
}
type match6 struct {
srcs []net6
dsts []npr6
}
type matches6 []match6
func (ms matches6) String() string {
var b strings.Builder
for _, m := range ms {
fmt.Fprintf(&b, "%s => %s\n", m.srcs, m.dsts)
}
return b.String()
}
func newMatches6(ms []Match) (ret matches6) {
for _, m := range ms {
var m6 match6
for _, src := range m.Srcs {
if src.IP.Is6() {
m6.srcs = append(m6.srcs, net6FromIPPrefix(src))
}
}
for _, dst := range m.Dsts {
if dst.Net.IP.Is6() {
m6.dsts = append(m6.dsts, npr6{net6FromIPPrefix(dst.Net), dst.Ports})
}
}
if len(m6.srcs) > 0 && len(m6.dsts) > 0 {
ret = append(ret, m6)
}
}
return ret
}
func (ms matches6) match(q *packet.Parsed) bool {
outer:
for i := range ms {
srcs := ms[i].srcs
for j := range srcs {
if srcs[j].Contains(q.SrcIP6) {
dsts := ms[i].dsts
for k := range dsts {
if dsts[k].net.Contains(q.DstIP6) && dsts[k].ports.contains(q.DstPort) {
return true
}
}
// We hit on src, but missed on all
// dsts. No need to try other srcs,
// they'll never fully match.
continue outer
}
}
}
return false
}
func (ms matches6) matchIPsOnly(q *packet.Parsed) bool {
outer:
for i := range ms {
srcs := ms[i].srcs
for j := range srcs {
if srcs[j].Contains(q.SrcIP6) {
dsts := ms[i].dsts
for k := range dsts {
if dsts[k].net.Contains(q.DstIP6) {
return true
}
}
// We hit on src, but missed on all
// dsts. No need to try other srcs,
// they'll never fully match.
continue outer
}
}
}
return false
}
func ip6InList(ip packet.IP6, netlist []net6) bool {
for _, net := range netlist {
if net.Contains(ip) {
return true
}
}
return false
}

View File

@ -1,37 +0,0 @@
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package filter
import "testing"
// Verifies that the fast bit-twiddling implementation of Contains
// works the same as the easy-to-read implementation. Since we can't
// sensibly check it on 128 bits, the test runs over 4-bit
// "IPs". Bit-twiddling is the same at any width, so this adequately
// proves that the implementations are equivalent.
func TestOptimizedContains(t *testing.T) {
for ipHi := 0; ipHi < 0xf; ipHi++ {
for ipLo := 0; ipLo < 0xf; ipLo++ {
for nIPHi := 0; nIPHi < 0xf; nIPHi++ {
for nIPLo := 0; nIPLo < 0xf; nIPLo++ {
for maskHi := 0; maskHi < 0xf; maskHi++ {
for maskLo := 0; maskLo < 0xf; maskLo++ {
a := (nIPHi ^ ipHi) & maskHi
b := (nIPLo ^ ipLo) & maskLo
got := (a | b) == 0
want := ((nIPHi&maskHi) == (ipHi&maskHi) && (nIPLo&maskLo) == (ipLo&maskLo))
if got != want {
t.Errorf("mask %1x%1x/%1x%1x %1x%1x got=%v want=%v", nIPHi, nIPLo, maskHi, maskLo, ipHi, ipLo, got, want)
}
}
}
}
}
}
}
}

View File

@ -16,6 +16,7 @@
"github.com/tailscale/wireguard-go/device" "github.com/tailscale/wireguard-go/device"
"github.com/tailscale/wireguard-go/tun" "github.com/tailscale/wireguard-go/tun"
"inet.af/netaddr"
"tailscale.com/net/packet" "tailscale.com/net/packet"
"tailscale.com/types/logger" "tailscale.com/types/logger"
"tailscale.com/wgengine/filter" "tailscale.com/wgengine/filter"
@ -67,8 +68,7 @@ type TUN struct {
lastActivityAtomic int64 // unix seconds of last send or receive lastActivityAtomic int64 // unix seconds of last send or receive
destIPActivity4 atomic.Value // of map[packet.IP4]func() destIPActivity atomic.Value // of map[netaddr.IP]func()
destIPActivity6 atomic.Value // of map[packet.IP6]func()
// buffer stores the oldest unconsumed packet from tdev. // buffer stores the oldest unconsumed packet from tdev.
// It is made a static buffer in order to avoid allocations. // It is made a static buffer in order to avoid allocations.
@ -137,9 +137,8 @@ func WrapTUN(logf logger.Logf, tdev tun.Device) *TUN {
// destination (the map keys). // destination (the map keys).
// //
// The map ownership passes to the TUN. It must be non-nil. // The map ownership passes to the TUN. It must be non-nil.
func (t *TUN) SetDestIPActivityFuncs(m4 map[packet.IP4]func(), m6 map[packet.IP6]func()) { func (t *TUN) SetDestIPActivityFuncs(m map[netaddr.IP]func()) {
t.destIPActivity4.Store(m4) t.destIPActivity.Store(m)
t.destIPActivity6.Store(m6)
} }
func (t *TUN) Close() error { func (t *TUN) Close() error {
@ -284,18 +283,9 @@ func (t *TUN) Read(buf []byte, offset int) (int, error) {
defer parsedPacketPool.Put(p) defer parsedPacketPool.Put(p)
p.Decode(buf[offset : offset+n]) p.Decode(buf[offset : offset+n])
switch p.IPVersion { if m, ok := t.destIPActivity.Load().(map[netaddr.IP]func()); ok {
case 4: if fn := m[p.Dst.IP]; fn != nil {
if m, ok := t.destIPActivity4.Load().(map[packet.IP4]func()); ok { fn()
if fn := m[p.DstIP4]; fn != nil {
fn()
}
}
case 6:
if m, ok := t.destIPActivity6.Load().(map[packet.IP6]func()); ok {
if fn := m[p.DstIP6]; fn != nil {
fn()
}
} }
} }

View File

@ -20,12 +20,20 @@
"tailscale.com/wgengine/filter" "tailscale.com/wgengine/filter"
) )
func udp(src, dst packet.IP4, sport, dport uint16) []byte { func udp4(src, dst string, sport, dport uint16) []byte {
sip, err := netaddr.ParseIP(src)
if err != nil {
panic(err)
}
dip, err := netaddr.ParseIP(dst)
if err != nil {
panic(err)
}
header := &packet.UDP4Header{ header := &packet.UDP4Header{
IP4Header: packet.IP4Header{ IP4Header: packet.IP4Header{
SrcIP: src, Src: sip,
DstIP: dst, Dst: dip,
IPID: 0, IPID: 0,
}, },
SrcPort: sport, SrcPort: sport,
DstPort: dport, DstPort: dport,
@ -252,12 +260,12 @@ func TestFilter(t *testing.T) {
}{ }{
{"junk_in", in, true, []byte("\x45not a valid IPv4 packet")}, {"junk_in", in, true, []byte("\x45not a valid IPv4 packet")},
{"junk_out", out, true, []byte("\x45not a valid IPv4 packet")}, {"junk_out", out, true, []byte("\x45not a valid IPv4 packet")},
{"bad_port_in", in, true, udp(0x05060708, 0x01020304, 22, 22)}, {"bad_port_in", in, true, udp4("5.6.7.8", "1.2.3.4", 22, 22)},
{"bad_port_out", out, false, udp(0x01020304, 0x05060708, 22, 22)}, {"bad_port_out", out, false, udp4("1.2.3.4", "5.6.7.8", 22, 22)},
{"bad_ip_in", in, true, udp(0x08010101, 0x01020304, 89, 89)}, {"bad_ip_in", in, true, udp4("8.1.1.1", "1.2.3.4", 89, 89)},
{"bad_ip_out", out, false, udp(0x01020304, 0x08010101, 98, 98)}, {"bad_ip_out", out, false, udp4("1.2.3.4", "8.1.1.1", 98, 98)},
{"good_packet_in", in, false, udp(0x05060708, 0x01020304, 89, 89)}, {"good_packet_in", in, false, udp4("5.6.7.8", "1.2.3.4", 89, 89)},
{"good_packet_out", out, false, udp(0x01020304, 0x05060708, 98, 98)}, {"good_packet_out", out, false, udp4("1.2.3.4", "5.6.7.8", 98, 98)},
} }
// A reader on the other end of the TUN. // A reader on the other end of the TUN.
@ -337,7 +345,7 @@ func BenchmarkWrite(b *testing.B) {
ftun, tun := newFakeTUN(b.Logf, true) ftun, tun := newFakeTUN(b.Logf, true)
defer tun.Close() defer tun.Close()
packet := udp(0x05060708, 0x01020304, 89, 89) packet := udp4("5.6.7.8", "1.2.3.4", 89, 89)
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
_, err := ftun.Write(packet, 0) _, err := ftun.Write(packet, 0)
if err != nil { if err != nil {

View File

@ -8,7 +8,6 @@
"bufio" "bufio"
"bytes" "bytes"
"context" "context"
"encoding/binary"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -58,10 +57,9 @@
// discovery. // discovery.
const minimalMTU = 1280 const minimalMTU = 1280
const ( const magicDNSPort = 53
magicDNSIP = 0x64646464 // 100.100.100.100
magicDNSPort = 53 var magicDNSIP = netaddr.IPv4(100, 100, 100, 100)
)
// Lazy wireguard-go configuration parameters. // Lazy wireguard-go configuration parameters.
const ( const (
@ -99,19 +97,17 @@ type userspaceEngine struct {
// localAddrs is the set of IP addresses assigned to the local // localAddrs is the set of IP addresses assigned to the local
// tunnel interface. It's used to reflect local packets // tunnel interface. It's used to reflect local packets
// incorrectly sent to us. // incorrectly sent to us.
localAddrs atomic.Value // of map[packet.IP4]bool localAddrs atomic.Value // of map[netaddr.IP]bool
wgLock sync.Mutex // serializes all wgdev operations; see lock order comment below wgLock sync.Mutex // serializes all wgdev operations; see lock order comment below
lastCfgFull wgcfg.Config lastCfgFull wgcfg.Config
lastRouterSig string // of router.Config lastRouterSig string // of router.Config
lastEngineSigFull string // of full wireguard config lastEngineSigFull string // of full wireguard config
lastEngineSigTrim string // of trimmed wireguard config lastEngineSigTrim string // of trimmed wireguard config
recvActivityAt map[tailcfg.DiscoKey]time.Time recvActivityAt map[tailcfg.DiscoKey]time.Time
trimmedDisco map[tailcfg.DiscoKey]bool // set of disco keys of peers currently excluded from wireguard config trimmedDisco map[tailcfg.DiscoKey]bool // set of disco keys of peers currently excluded from wireguard config
sentActivityAt4 map[packet.IP4]*int64 // value is atomic int64 of unixtime sentActivityAt map[netaddr.IP]*int64 // value is atomic int64 of unixtime
destIPActivityFuncs4 map[packet.IP4]func() destIPActivityFuncs map[netaddr.IP]func()
sentActivityAt6 map[packet.IP6]*int64 // value is atomic int64 of unixtime
destIPActivityFuncs6 map[packet.IP6]func()
mu sync.Mutex // guards following; see lock order comment below mu sync.Mutex // guards following; see lock order comment below
closing bool // Close was called (even if we're still closing) closing bool // Close was called (even if we're still closing)
@ -208,7 +204,7 @@ func newUserspaceEngineAdvanced(conf EngineConfig) (_ Engine, reterr error) {
resolver: tsdns.NewResolver(rconf), resolver: tsdns.NewResolver(rconf),
pingers: make(map[wgcfg.Key]*pinger), pingers: make(map[wgcfg.Key]*pinger),
} }
e.localAddrs.Store(map[packet.IP4]bool{}) e.localAddrs.Store(map[netaddr.IP]bool{})
e.linkState, _ = getLinkState() e.linkState, _ = getLinkState()
logf("link state: %+v", e.linkState) logf("link state: %+v", e.linkState)
@ -399,7 +395,7 @@ func (e *userspaceEngine) handleLocalPackets(p *packet.Parsed, t *tstun.TUN) fil
return filter.Drop return filter.Drop
} }
if (runtime.GOOS == "darwin" || runtime.GOOS == "ios") && e.isLocalAddr(p.DstIP4) { if (runtime.GOOS == "darwin" || runtime.GOOS == "ios") && e.isLocalAddr(p.Dst.IP) {
// macOS NetworkExtension directs packets destined to the // macOS NetworkExtension directs packets destined to the
// tunnel's local IP address into the tunnel, instead of // tunnel's local IP address into the tunnel, instead of
// looping back within the kernel network stack. We have to // looping back within the kernel network stack. We have to
@ -412,8 +408,8 @@ func (e *userspaceEngine) handleLocalPackets(p *packet.Parsed, t *tstun.TUN) fil
return filter.Accept return filter.Accept
} }
func (e *userspaceEngine) isLocalAddr(ip packet.IP4) bool { func (e *userspaceEngine) isLocalAddr(ip netaddr.IP) bool {
localAddrs, ok := e.localAddrs.Load().(map[packet.IP4]bool) localAddrs, ok := e.localAddrs.Load().(map[netaddr.IP]bool)
if !ok { if !ok {
e.logf("[unexpected] e.localAddrs was nil, can't check for loopback packet") e.logf("[unexpected] e.localAddrs was nil, can't check for loopback packet")
return false return false
@ -423,10 +419,10 @@ func (e *userspaceEngine) isLocalAddr(ip packet.IP4) bool {
// handleDNS is an outbound pre-filter resolving Tailscale domains. // handleDNS is an outbound pre-filter resolving Tailscale domains.
func (e *userspaceEngine) handleDNS(p *packet.Parsed, t *tstun.TUN) filter.Response { func (e *userspaceEngine) handleDNS(p *packet.Parsed, t *tstun.TUN) filter.Response {
if p.DstIP4 == magicDNSIP && p.DstPort == magicDNSPort && p.IPProto == packet.UDP { if p.Dst.IP == magicDNSIP && p.Dst.Port == magicDNSPort && p.IPProto == packet.UDP {
request := tsdns.Packet{ request := tsdns.Packet{
Payload: append([]byte(nil), p.Payload()...), Payload: append([]byte(nil), p.Payload()...),
Addr: netaddr.IPPort{IP: p.SrcIP4.Netaddr(), Port: p.SrcPort}, Addr: netaddr.IPPort{IP: p.Src.IP, Port: p.Src.Port},
} }
err := e.resolver.EnqueueRequest(request) err := e.resolver.EnqueueRequest(request)
if err != nil { if err != nil {
@ -451,8 +447,8 @@ func (e *userspaceEngine) pollResolver() {
h := packet.UDP4Header{ h := packet.UDP4Header{
IP4Header: packet.IP4Header{ IP4Header: packet.IP4Header{
SrcIP: packet.IP4(magicDNSIP), Src: magicDNSIP,
DstIP: packet.IP4FromNetaddr(resp.Addr.IP), Dst: resp.Addr.IP,
}, },
SrcPort: magicDNSPort, SrcPort: magicDNSPort,
DstPort: resp.Addr.Port, DstPort: resp.Addr.Port,
@ -489,7 +485,7 @@ func (p *pinger) close() {
<-p.done <-p.done
} }
func (p *pinger) run(ctx context.Context, peerKey wgcfg.Key, ips []wgcfg.IP, srcIP packet.IP4) { func (p *pinger) run(ctx context.Context, peerKey wgcfg.Key, ips []wgcfg.IP, srcIP netaddr.IP) {
defer func() { defer func() {
p.e.mu.Lock() p.e.mu.Lock()
if p.e.pingers[peerKey] == p { if p.e.pingers[peerKey] == p {
@ -502,7 +498,7 @@ func (p *pinger) run(ctx context.Context, peerKey wgcfg.Key, ips []wgcfg.IP, src
header := packet.ICMP4Header{ header := packet.ICMP4Header{
IP4Header: packet.IP4Header{ IP4Header: packet.IP4Header{
SrcIP: srcIP, Src: srcIP,
}, },
Type: packet.ICMP4EchoRequest, Type: packet.ICMP4EchoRequest,
Code: packet.ICMP4NoCode, Code: packet.ICMP4NoCode,
@ -515,7 +511,7 @@ func (p *pinger) run(ctx context.Context, peerKey wgcfg.Key, ips []wgcfg.IP, src
const stopAfter = 3 * time.Second const stopAfter = 3 * time.Second
start := time.Now() start := time.Now()
var dstIPs []packet.IP4 var dstIPs []netaddr.IP
for _, ip := range ips { for _, ip := range ips {
if ip.Is6() { if ip.Is6() {
// This code is only used for legacy (pre-discovery) // This code is only used for legacy (pre-discovery)
@ -524,7 +520,7 @@ func (p *pinger) run(ctx context.Context, peerKey wgcfg.Key, ips []wgcfg.IP, src
// work. // work.
continue continue
} }
dstIPs = append(dstIPs, packet.IP4FromNetaddr(netaddr.IPFrom16(ip.Addr))) dstIPs = append(dstIPs, netaddr.IPFrom16(ip.Addr))
} }
payload := []byte("magicsock_spray") // no meaning payload := []byte("magicsock_spray") // no meaning
@ -542,7 +538,7 @@ func (p *pinger) run(ctx context.Context, peerKey wgcfg.Key, ips []wgcfg.IP, src
return return
} }
for _, dstIP := range dstIPs { for _, dstIP := range dstIPs {
header.DstIP = dstIP header.Dst = dstIP
// InjectOutbound take ownership of the packet, so we allocate. // InjectOutbound take ownership of the packet, so we allocate.
b := packet.Generate(&header, payload) b := packet.Generate(&header, payload)
p.e.tundev.InjectOutbound(b) p.e.tundev.InjectOutbound(b)
@ -560,15 +556,15 @@ func (p *pinger) run(ctx context.Context, peerKey wgcfg.Key, ips []wgcfg.IP, src
// have advertised discovery keys. // have advertised discovery keys.
func (e *userspaceEngine) pinger(peerKey wgcfg.Key, ips []wgcfg.IP) { func (e *userspaceEngine) pinger(peerKey wgcfg.Key, ips []wgcfg.IP) {
e.logf("[v1] generating initial ping traffic to %s (%v)", peerKey.ShortString(), ips) e.logf("[v1] generating initial ping traffic to %s (%v)", peerKey.ShortString(), ips)
var srcIP packet.IP4 var srcIP netaddr.IP
e.wgLock.Lock() e.wgLock.Lock()
if len(e.lastCfgFull.Addresses) > 0 { if len(e.lastCfgFull.Addresses) > 0 {
srcIP = packet.IP4FromNetaddr(netaddr.IPFrom16(e.lastCfgFull.Addresses[0].IP.Addr)) srcIP = netaddr.IPFrom16(e.lastCfgFull.Addresses[0].IP.Addr)
} }
e.wgLock.Unlock() e.wgLock.Unlock()
if srcIP == 0 { if srcIP.IsZero() {
e.logf("generating initial ping traffic: no source IP") e.logf("generating initial ping traffic: no source IP")
return return
} }
@ -694,17 +690,8 @@ func (e *userspaceEngine) isActiveSince(dk tailcfg.DiscoKey, ip wgcfg.IP, t time
if e.recvActivityAt[dk].After(t) { if e.recvActivityAt[dk].After(t) {
return true return true
} }
var ( pip := netaddr.IPFrom16(ip.Addr)
timePtr *int64 timePtr, ok := e.sentActivityAt[pip]
ok bool
)
if ip.Is4() {
pip := packet.IP4(binary.BigEndian.Uint32(ip.Addr[12:]))
timePtr, ok = e.sentActivityAt4[pip]
} else {
pip := packet.IP6FromRaw16(ip.Addr)
timePtr, ok = e.sentActivityAt6[pip]
}
if !ok { if !ok {
return false return false
} }
@ -845,14 +832,10 @@ func (e *userspaceEngine) updateActivityMapsLocked(trackDisco []tailcfg.DiscoKey
} }
e.recvActivityAt = mr e.recvActivityAt = mr
oldTime4 := e.sentActivityAt4 oldTime := e.sentActivityAt
e.sentActivityAt4 = make(map[packet.IP4]*int64, len(oldTime4)) e.sentActivityAt = make(map[netaddr.IP]*int64, len(oldTime))
oldFunc4 := e.destIPActivityFuncs4 oldFunc := e.destIPActivityFuncs
e.destIPActivityFuncs4 = make(map[packet.IP4]func(), len(oldFunc4)) e.destIPActivityFuncs = make(map[netaddr.IP]func(), len(oldFunc))
oldTime6 := e.sentActivityAt6
e.sentActivityAt6 = make(map[packet.IP6]*int64, len(oldTime6))
oldFunc6 := e.destIPActivityFuncs6
e.destIPActivityFuncs6 = make(map[packet.IP6]func(), len(oldFunc6))
updateFn := func(timePtr *int64) func() { updateFn := func(timePtr *int64) func() {
return func() { return func() {
@ -877,35 +860,20 @@ func (e *userspaceEngine) updateActivityMapsLocked(trackDisco []tailcfg.DiscoKey
} }
for _, wip := range trackIPs { for _, wip := range trackIPs {
if wip.Is4() { pip := netaddr.IPFrom16(wip.Addr)
pip := packet.IP4(binary.BigEndian.Uint32(wip.Addr[12:])) timePtr := oldTime[pip]
timePtr := oldTime4[pip] if timePtr == nil {
if timePtr == nil { timePtr = new(int64)
timePtr = new(int64)
}
e.sentActivityAt4[pip] = timePtr
fn := oldFunc4[pip]
if fn == nil {
fn = updateFn(timePtr)
}
e.destIPActivityFuncs4[pip] = fn
} else {
pip := packet.IP6FromRaw16(wip.Addr)
timePtr := oldTime6[pip]
if timePtr == nil {
timePtr = new(int64)
}
e.sentActivityAt6[pip] = timePtr
fn := oldFunc6[pip]
if fn == nil {
fn = updateFn(timePtr)
}
e.destIPActivityFuncs6[pip] = fn
} }
e.sentActivityAt[pip] = timePtr
fn := oldFunc[pip]
if fn == nil {
fn = updateFn(timePtr)
}
e.destIPActivityFuncs[pip] = fn
} }
e.tundev.SetDestIPActivityFuncs(e.destIPActivityFuncs4, e.destIPActivityFuncs6) e.tundev.SetDestIPActivityFuncs(e.destIPActivityFuncs)
} }
func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config) error { func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config) error {
@ -913,13 +881,9 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config)
panic("routerCfg must not be nil") panic("routerCfg must not be nil")
} }
localAddrs := map[packet.IP4]bool{} localAddrs := map[netaddr.IP]bool{}
for _, addr := range routerCfg.LocalAddrs { for _, addr := range routerCfg.LocalAddrs {
// TODO: ipv6 localAddrs[addr.IP] = true
if !addr.IP.Is4() {
continue
}
localAddrs[packet.IP4FromNetaddr(addr.IP)] = true
} }
e.localAddrs.Store(localAddrs) e.localAddrs.Store(localAddrs)