net/tstun: use gaissmai/bart instead of tempfork/device

This implementation uses less memory than tempfork/device,
which helps avoid OOM conditions in the iOS VPN extension when
switching to a Tailnet with ExitNode routing enabled.

Updates tailscale/corp#18514

Signed-off-by: Percy Wegmann <percy@tailscale.com>
This commit is contained in:
Percy Wegmann
2024-03-22 17:23:53 -05:00
committed by Percy Wegmann
parent 1e7050e73a
commit 8b8b315258
12 changed files with 32 additions and 853 deletions

View File

@@ -1,90 +0,0 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package table provides a Routing Table implementation which allows
// looking up the peer that should be used to route a given IP address.
package table
import (
"net/netip"
"tailscale.com/tempfork/device"
"tailscale.com/types/key"
"tailscale.com/util/mak"
)
// RoutingTableBuilder is a builder for a RoutingTable.
// It is not safe for concurrent use.
type RoutingTableBuilder struct {
// peers is a map from node public key to the peer that owns that key.
// It is only used to handle insertions and deletions.
peers map[key.NodePublic]*device.Peer
// prefixTrie is a trie of prefixes which facilitates looking up the
// peer that owns a given IP address.
prefixTrie *device.AllowedIPs
}
// Remove removes the given peer from the routing table.
func (t *RoutingTableBuilder) Remove(peer key.NodePublic) {
p, ok := t.peers[peer]
if !ok {
return
}
t.prefixTrie.RemoveByPeer(p)
delete(t.peers, peer)
}
// InsertOrReplace inserts the given peer and prefixes into the routing table.
func (t *RoutingTableBuilder) InsertOrReplace(peer key.NodePublic, pfxs ...netip.Prefix) {
p, ok := t.peers[peer]
if !ok {
p = device.NewPeer(peer)
mak.Set(&t.peers, peer, p)
} else {
t.prefixTrie.RemoveByPeer(p)
}
if len(pfxs) == 0 {
return
}
if t.prefixTrie == nil {
t.prefixTrie = new(device.AllowedIPs)
}
for _, pfx := range pfxs {
t.prefixTrie.Insert(pfx, p)
}
}
// Build returns a RoutingTable that can be used to look up peers.
// Build resets the RoutingTableBuilder to its zero value.
func (t *RoutingTableBuilder) Build() *RoutingTable {
pt := t.prefixTrie
t.prefixTrie = nil
t.peers = nil
return &RoutingTable{
prefixTrie: pt,
}
}
// RoutingTable provides a mapping from IP addresses to peers identified by
// their public node key. It is used to find the peer that should be used to
// route a given IP address.
// It is immutable after creation.
//
// It is safe for concurrent use.
type RoutingTable struct {
prefixTrie *device.AllowedIPs
}
// Lookup returns the peer that would be used to route the given IP address.
// If no peer is found, Lookup returns the zero value.
func (t *RoutingTable) Lookup(ip netip.Addr) (_ key.NodePublic, ok bool) {
if t == nil {
return key.NodePublic{}, false
}
p := t.prefixTrie.Lookup(ip.AsSlice())
if p == nil {
return key.NodePublic{}, false
}
return p.Key(), true
}

View File

@@ -18,6 +18,7 @@ import (
"sync/atomic"
"time"
"github.com/gaissmai/bart"
"github.com/tailscale/wireguard-go/device"
"github.com/tailscale/wireguard-go/tun"
"go4.org/mem"
@@ -27,7 +28,6 @@ import (
"tailscale.com/net/packet"
"tailscale.com/net/packet/checksum"
"tailscale.com/net/tsaddr"
"tailscale.com/net/tstun/table"
"tailscale.com/syncs"
"tailscale.com/tstime/mono"
"tailscale.com/types/ipproto"
@@ -611,15 +611,13 @@ type natFamilyConfig struct {
// peers will use to connect to this node.
listenAddrs views.Map[netip.Addr, struct{}] // masqAddr -> struct{}
// dstMasqAddrs is map of dst addresses to their respective MasqueradeAsIP
// addresses. The MasqueradeAsIP address is the address that should be used
// as the source address for packets to dst.
dstMasqAddrs views.Map[key.NodePublic, netip.Addr] // dst -> masqAddr
// dstMasqAddrs is the routing table used to map a given dst IP to the
// respective MasqueradeAsIP address. The MasqueradeAsIP address is the
// address that should be used as the source address for packets to dst.
dstMasqAddrs *bart.Table[netip.Addr]
// dstAddrToPeerKeyMapper is the routing table used to map a given dst IP to
// the peer key responsible for that IP.
// It only contains peers that require a MasqueradeAsIP address.
dstAddrToPeerKeyMapper *table.RoutingTable
// masqAddrCounts is a count of peers by MasqueradeAsIP.
masqAddrCounts map[netip.Addr]int
}
func (c *natFamilyConfig) String() string {
@@ -640,15 +638,10 @@ func (c *natFamilyConfig) String() string {
i++
return true
})
count := map[netip.Addr]int{}
c.dstMasqAddrs.Range(func(_ key.NodePublic, v netip.Addr) bool {
count[v]++
return true
})
i = 0
b.WriteString("], dstMasqAddrs: [")
for k, v := range count {
for k, v := range c.masqAddrCounts {
if i > 0 {
b.WriteString(", ")
}
@@ -682,14 +675,11 @@ func (c *natFamilyConfig) selectSrcIP(oldSrc, dst netip.Addr) netip.Addr {
if oldSrc != c.nativeAddr {
return oldSrc
}
p, ok := c.dstAddrToPeerKeyMapper.Lookup(dst)
eip, ok := c.dstMasqAddrs.Get(dst)
if !ok {
return oldSrc
}
if eip, ok := c.dstMasqAddrs.GetOk(p); ok {
return eip
}
return oldSrc
return eip
}
// natConfigFromWGConfig generates a natFamilyConfig from nm,
@@ -712,9 +702,9 @@ func natConfigFromWGConfig(wcfg *wgcfg.Config, addrFam ipproto.Version) *natFami
}
var (
rt table.RoutingTableBuilder
dstMasqAddrs map[key.NodePublic]netip.Addr
listenAddrs set.Set[netip.Addr]
rt bart.Table[netip.Addr]
masqAddrCounts = map[netip.Addr]int{}
listenAddrs set.Set[netip.Addr]
)
// When using an exit node that requires masquerading, we need to
@@ -747,17 +737,20 @@ func natConfigFromWGConfig(wcfg *wgcfg.Config, addrFam ipproto.Version) *natFami
} else {
continue
}
rt.InsertOrReplace(p.PublicKey, p.AllowedIPs...)
mak.Set(&dstMasqAddrs, p.PublicKey, addrToUse)
masqAddrCounts[addrToUse]++
for _, ip := range p.AllowedIPs {
rt.Insert(ip, addrToUse)
}
}
if len(listenAddrs) == 0 && len(dstMasqAddrs) == 0 {
if len(listenAddrs) == 0 && len(masqAddrCounts) == 0 {
return nil
}
return &natFamilyConfig{
nativeAddr: nativeAddr,
listenAddrs: views.MapOf(listenAddrs),
dstMasqAddrs: views.MapOf(dstMasqAddrs),
dstAddrToPeerKeyMapper: rt.Build(),
nativeAddr: nativeAddr,
listenAddrs: views.MapOf(listenAddrs),
dstMasqAddrs: &rt,
masqAddrCounts: masqAddrCounts,
}
}