net/netaddr: start migrating to net/netip via new netaddr adapter package

Updates #5162

Change-Id: Id7bdec303b25471f69d542f8ce43805328d56c12
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick
2022-07-24 20:08:42 -07:00
committed by Brad Fitzpatrick
parent 7b1a91dfd3
commit 7eaf5e509f
191 changed files with 1009 additions and 888 deletions

View File

@@ -6,10 +6,11 @@ package natlab
import (
"fmt"
"net/netip"
"sync"
"time"
"inet.af/netaddr"
"tailscale.com/net/netaddr"
)
// FirewallType is the type of filtering a stateful firewall
@@ -52,7 +53,7 @@ func (s FirewallType) key(src, dst netaddr.IPPort) fwKey {
switch s {
case EndpointIndependentFirewall:
case AddressDependentFirewall:
k.dst = k.dst.WithIP(dst.IP())
k.dst = netip.AddrPortFrom(dst.Addr(), k.dst.Port())
case AddressAndPortDependentFirewall:
k.dst = dst
default:

View File

@@ -8,10 +8,11 @@ import (
"context"
"fmt"
"net"
"net/netip"
"sync"
"time"
"inet.af/netaddr"
"tailscale.com/net/netaddr"
)
// mapping is the state of an allocated NAT session.
@@ -62,7 +63,7 @@ func (t NATType) key(src, dst netaddr.IPPort) natKey {
switch t {
case EndpointIndependentNAT:
case AddressDependentNAT:
k.dst = k.dst.WithIP(dst.IP())
k.dst = netip.AddrPortFrom(dst.Addr(), k.dst.Port())
case AddressAndPortDependentNAT:
k.dst = dst
default:
@@ -171,7 +172,7 @@ func (n *SNAT44) HandleIn(p *Packet, iif *Interface) *Packet {
func (n *SNAT44) HandleForward(p *Packet, iif, oif *Interface) *Packet {
switch {
case oif == n.ExternalInterface:
if p.Src.IP() == oif.V4() {
if p.Src.Addr() == oif.V4() {
// Packet already NATed and is just retraversing Forward,
// don't touch it again.
return p

View File

@@ -17,13 +17,14 @@ import (
"fmt"
"math/rand"
"net"
"net/netip"
"os"
"sort"
"strconv"
"sync"
"time"
"inet.af/netaddr"
"tailscale.com/net/netaddr"
)
var traceOn, _ = strconv.ParseBool(os.Getenv("NATLAB_TRACE"))
@@ -131,11 +132,11 @@ func (n *Network) addMachineLocked(ip netaddr.IP, iface *Interface) {
func (n *Network) allocIPv4(iface *Interface) netaddr.IP {
n.mu.Lock()
defer n.mu.Unlock()
if n.Prefix4.IsZero() {
if !n.Prefix4.IsValid() {
return netaddr.IP{}
}
if n.lastV4.IsZero() {
n.lastV4 = n.Prefix4.IP()
if !n.lastV4.IsValid() {
n.lastV4 = n.Prefix4.Addr()
}
a := n.lastV4.As16()
addOne(&a, 15)
@@ -150,11 +151,11 @@ func (n *Network) allocIPv4(iface *Interface) netaddr.IP {
func (n *Network) allocIPv6(iface *Interface) netaddr.IP {
n.mu.Lock()
defer n.mu.Unlock()
if n.Prefix6.IsZero() {
if !n.Prefix6.IsValid() {
return netaddr.IP{}
}
if n.lastV6.IsZero() {
n.lastV6 = n.Prefix6.IP()
if !n.lastV6.IsValid() {
n.lastV6 = n.Prefix6.Addr()
}
a := n.lastV6.As16()
addOne(&a, 15)
@@ -180,21 +181,21 @@ func (n *Network) write(p *Packet) (num int, err error) {
n.mu.Lock()
defer n.mu.Unlock()
iface, ok := n.machine[p.Dst.IP()]
iface, ok := n.machine[p.Dst.Addr()]
if !ok {
// If the destination is within the network's authoritative
// range, no route to host.
if p.Dst.IP().Is4() && n.Prefix4.Contains(p.Dst.IP()) {
p.Trace("no route to %v", p.Dst.IP)
if p.Dst.Addr().Is4() && n.Prefix4.Contains(p.Dst.Addr()) {
p.Trace("no route to %v", p.Dst.Addr())
return len(p.Payload), nil
}
if p.Dst.IP().Is6() && n.Prefix6.Contains(p.Dst.IP()) {
p.Trace("no route to %v", p.Dst.IP)
if p.Dst.Addr().Is6() && n.Prefix6.Contains(p.Dst.Addr()) {
p.Trace("no route to %v", p.Dst.Addr())
return len(p.Payload), nil
}
if n.defaultGW == nil {
p.Trace("no route to %v", p.Dst.IP)
p.Trace("no route to %v", p.Dst.Addr())
return len(p.Payload), nil
}
iface = n.defaultGW
@@ -360,7 +361,7 @@ func (m *Machine) isLocalIP(ip netaddr.IP) bool {
func (m *Machine) deliverIncomingPacket(p *Packet, iface *Interface) {
p.setLocator("mach=%s if=%s", m.Name, iface.name)
if m.isLocalIP(p.Dst.IP()) {
if m.isLocalIP(p.Dst.Addr()) {
m.deliverLocalPacket(p, iface)
} else {
m.forwardPacket(p, iface)
@@ -388,7 +389,7 @@ func (m *Machine) deliverLocalPacket(p *Packet, iface *Interface) {
defer m.mu.Unlock()
conns := m.conns4
if p.Dst.IP().Is6() {
if p.Dst.Addr().Is6() {
conns = m.conns6
}
possibleDsts := []netaddr.IPPort{
@@ -414,7 +415,7 @@ func (m *Machine) deliverLocalPacket(p *Packet, iface *Interface) {
}
func (m *Machine) forwardPacket(p *Packet, iif *Interface) {
oif, err := m.interfaceForIP(p.Dst.IP())
oif, err := m.interfaceForIP(p.Dst.Addr())
if err != nil {
p.Trace("%v", err)
return
@@ -462,10 +463,10 @@ func (m *Machine) Attach(interfaceName string, n *Network) *Interface {
net: n,
name: interfaceName,
}
if ip := n.allocIPv4(f); !ip.IsZero() {
if ip := n.allocIPv4(f); ip.IsValid() {
f.ips = append(f.ips, ip)
}
if ip := n.allocIPv6(f); !ip.IsZero() {
if ip := n.allocIPv6(f); ip.IsValid() {
f.ips = append(f.ips, ip)
}
@@ -484,13 +485,13 @@ func (m *Machine) Attach(interfaceName string, n *Network) *Interface {
iface: f,
})
} else {
if !n.Prefix4.IsZero() {
if n.Prefix4.IsValid() {
m.routes = append(m.routes, routeEntry{
prefix: n.Prefix4,
iface: f,
})
}
if !n.Prefix6.IsZero() {
if n.Prefix6.IsValid() {
m.routes = append(m.routes, routeEntry{
prefix: n.Prefix6,
iface: f,
@@ -506,39 +507,39 @@ func (m *Machine) Attach(interfaceName string, n *Network) *Interface {
var (
v4unspec = netaddr.IPv4(0, 0, 0, 0)
v6unspec = netaddr.IPv6Unspecified()
v6unspec = netip.IPv6Unspecified()
)
func (m *Machine) writePacket(p *Packet) (n int, err error) {
p.setLocator("mach=%s", m.Name)
iface, err := m.interfaceForIP(p.Dst.IP())
iface, err := m.interfaceForIP(p.Dst.Addr())
if err != nil {
p.Trace("%v", err)
return 0, err
}
origSrcIP := p.Src.IP()
origSrcIP := p.Src.Addr()
switch {
case p.Src.IP() == v4unspec:
case p.Src.Addr() == v4unspec:
p.Trace("assigning srcIP=%s", iface.V4())
p.Src = p.Src.WithIP(iface.V4())
case p.Src.IP() == v6unspec:
p.Src = netip.AddrPortFrom(iface.V4(), p.Src.Port())
case p.Src.Addr() == v6unspec:
// v6unspec in Go means "any src, but match address families"
if p.Dst.IP().Is6() {
if p.Dst.Addr().Is6() {
p.Trace("assigning srcIP=%s", iface.V6())
p.Src = p.Src.WithIP(iface.V6())
} else if p.Dst.IP().Is4() {
p.Src = netip.AddrPortFrom(iface.V6(), p.Src.Port())
} else if p.Dst.Addr().Is4() {
p.Trace("assigning srcIP=%s", iface.V4())
p.Src = p.Src.WithIP(iface.V4())
p.Src = netip.AddrPortFrom(iface.V4(), p.Src.Port())
}
default:
if !iface.Contains(p.Src.IP()) {
err := fmt.Errorf("can't send to %v with src %v on interface %v", p.Dst.IP(), p.Src.IP(), iface)
if !iface.Contains(p.Src.Addr()) {
err := fmt.Errorf("can't send to %v with src %v on interface %v", p.Dst.Addr(), p.Src.Addr(), iface)
p.Trace("%v", err)
return 0, err
}
}
if p.Src.IP().IsZero() {
if !p.Src.Addr().IsValid() {
err := fmt.Errorf("no matching address for address family for %v", origSrcIP)
p.Trace("%v", err)
return 0, err
@@ -614,7 +615,7 @@ func (m *Machine) portInUseLocked(port uint16) bool {
func (m *Machine) registerConn4(c *conn) error {
m.mu.Lock()
defer m.mu.Unlock()
if c.ipp.IP().Is6() && c.ipp.IP() != v6unspec {
if c.ipp.Addr().Is6() && c.ipp.Addr() != v6unspec {
return fmt.Errorf("registerConn4 got IPv6 %s", c.ipp)
}
return registerConn(&m.conns4, c)
@@ -629,7 +630,7 @@ func (m *Machine) unregisterConn4(c *conn) {
func (m *Machine) registerConn6(c *conn) error {
m.mu.Lock()
defer m.mu.Unlock()
if c.ipp.IP().Is4() {
if c.ipp.Addr().Is4() {
return fmt.Errorf("registerConn6 got IPv4 %s", c.ipp)
}
return registerConn(&m.conns6, c)
@@ -804,7 +805,11 @@ func (c *conn) breakActiveReadsLocked() {
}
func (c *conn) LocalAddr() net.Addr {
return c.ipp.UDPAddr()
return &net.UDPAddr{
IP: c.ipp.Addr().AsSlice(),
Port: int(c.ipp.Port()),
Zone: c.ipp.Addr().Zone(),
}
}
func (c *conn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
@@ -824,7 +829,12 @@ func (c *conn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
case pkt := <-c.in:
n = copy(p, pkt.Payload)
pkt.Trace("PacketConn.ReadFrom")
return n, pkt.Src.UDPAddr(), nil
ua := &net.UDPAddr{
IP: pkt.Src.Addr().AsSlice(),
Port: int(pkt.Src.Port()),
Zone: pkt.Src.Addr().Zone(),
}
return n, ua, nil
case <-ctx.Done():
return 0, nil, context.DeadlineExceeded
}
@@ -835,6 +845,10 @@ func (c *conn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
if err != nil {
return 0, fmt.Errorf("bogus addr %T %q", addr, addr.String())
}
return c.WriteToUDPAddrPort(p, ipp)
}
func (c *conn) WriteToUDPAddrPort(p []byte, ipp netip.AddrPort) (n int, err error) {
pkt := &Packet{
Src: c.ipp,
Dst: ipp,

View File

@@ -7,10 +7,12 @@ package natlab
import (
"context"
"fmt"
"net"
"net/netip"
"testing"
"time"
"inet.af/netaddr"
"tailscale.com/net/netaddr"
"tailscale.com/tstest"
)
@@ -63,7 +65,7 @@ func TestSendPacket(t *testing.T) {
}
const msg = "some message"
if _, err := fooPC.WriteTo([]byte(msg), barAddr.UDPAddr()); err != nil {
if _, err := fooPC.WriteTo([]byte(msg), net.UDPAddrFromAddrPort(barAddr)); err != nil {
t.Fatal(err)
}
@@ -117,10 +119,10 @@ func TestMultiNetwork(t *testing.T) {
serverAddr := netaddr.IPPortFrom(ifServer.V4(), 789)
const msg1, msg2 = "hello", "world"
if _, err := natPC.WriteTo([]byte(msg1), clientAddr.UDPAddr()); err != nil {
if _, err := natPC.WriteTo([]byte(msg1), net.UDPAddrFromAddrPort(clientAddr)); err != nil {
t.Fatal(err)
}
if _, err := natPC.WriteTo([]byte(msg2), serverAddr.UDPAddr()); err != nil {
if _, err := natPC.WriteTo([]byte(msg2), net.UDPAddrFromAddrPort(serverAddr)); err != nil {
t.Fatal(err)
}
@@ -154,8 +156,8 @@ type trivialNAT struct {
}
func (n *trivialNAT) HandleIn(p *Packet, iface *Interface) *Packet {
if iface == n.wanIf && p.Dst.IP() == n.wanIf.V4() {
p.Dst = p.Dst.WithIP(n.clientIP)
if iface == n.wanIf && p.Dst.Addr() == n.wanIf.V4() {
p.Dst = netip.AddrPortFrom(n.clientIP, p.Dst.Port())
}
return p
}
@@ -167,13 +169,13 @@ func (n trivialNAT) HandleOut(p *Packet, iface *Interface) *Packet {
func (n *trivialNAT) HandleForward(p *Packet, iif, oif *Interface) *Packet {
// Outbound from LAN -> apply NAT, continue
if iif == n.lanIf && oif == n.wanIf {
if p.Src.IP() == n.clientIP {
p.Src = p.Src.WithIP(n.wanIf.V4())
if p.Src.Addr() == n.clientIP {
p.Src = netip.AddrPortFrom(n.wanIf.V4(), p.Src.Port())
}
return p
}
// Return traffic to LAN, allow if right dst.
if iif == n.wanIf && oif == n.lanIf && p.Dst.IP() == n.clientIP {
if iif == n.wanIf && oif == n.lanIf && p.Dst.Addr() == n.clientIP {
return p
}
// Else drop.
@@ -217,7 +219,7 @@ func TestPacketHandler(t *testing.T) {
const msg = "some message"
serverAddr := netaddr.IPPortFrom(ifServer.V4(), 456)
if _, err := clientPC.WriteTo([]byte(msg), serverAddr.UDPAddr()); err != nil {
if _, err := clientPC.WriteTo([]byte(msg), net.UDPAddrFromAddrPort(serverAddr)); err != nil {
t.Fatal(err)
}