tailscale/util/linuxfw/nftables_runner.go
Naman Sood d46a4eced5
util/linuxfw, wgengine: allow ingress to magicsock UDP port on Linux (#10370)
* util/linuxfw, wgengine: allow ingress to magicsock UDP port on Linux

Updates #9084.

Currently, we have to tell users to manually open UDP ports on Linux when
certain firewalls (like ufw) are enabled. This change automates the process of
adding and updating those firewall rules as magicsock changes what port it
listens on.

Signed-off-by: Naman Sood <mail@nsood.in>
2023-12-05 18:12:02 -05:00

1755 lines
51 KiB
Go

// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build linux
package linuxfw
import (
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
"net"
"net/netip"
"reflect"
"strings"
"github.com/google/nftables"
"github.com/google/nftables/expr"
"golang.org/x/sys/unix"
"tailscale.com/net/tsaddr"
"tailscale.com/types/logger"
"tailscale.com/types/ptr"
)
// Names of the Tailscale-managed chains. Each one starts with the "ts-"
// prefix that isTSChain tests for.
const (
chainNameForward = "ts-forward"
chainNameInput = "ts-input"
chainNamePostrouting = "ts-postrouting"
)
// chainTypeRegular is the chain type of an nftables chain that does not apply
// to a hook. AddChains passes it, together with nil hook/priority/policy, when
// creating the ts-* chains.
const chainTypeRegular = ""
// chainInfo describes a chain that should exist in a table, along with the
// attributes to create it with if it is missing.
//
// The field order is significant: AddChains and createDummyPostroutingChains
// build chainInfo values with positional (unkeyed) literals.
type chainInfo struct {
table *nftables.Table // table the chain belongs to
name string // chain name, unique within the table
chainType nftables.ChainType // chainTypeRegular for unhooked chains
chainHook *nftables.ChainHook // nil for unhooked chains
chainPriority *nftables.ChainPriority // nil for unhooked chains
chainPolicy *nftables.ChainPolicy // nil for unhooked chains
}
// nftable groups the tables used by the runner for one address family.
type nftable struct {
Proto nftables.TableFamily // address family (IPv4 or IPv6) of the tables below
Filter *nftables.Table // "filter" table; populated by AddChains
Nat *nftables.Table // "nat" table; populated by AddChains (and by createDummyPostroutingChains)
}
// nftablesRunner implements a netfilterRunner using the netlink based nftables
// library. As nftables allows for arbitrary tables and chains, there is a need
// to follow conventions in order to integrate well with a surrounding
// ecosystem. The rules installed by nftablesRunner have the following
// properties:
// - Install rules that intend to take precedence over rules installed by
// other software. Tailscale provides packet filtering for tailnet traffic
// inside the daemon based on the tailnet ACL rules.
// - As nftables "accept" is not final, rules from high priority tables (low
// numbers) will fall through to lower priority tables (high numbers). In
// order to effectively be 'final', we install "jump" rules into conventional
// tables and chains that will reach an accept verdict inside those tables.
// - The table and chain conventions followed here are those used by
// `iptables-nft` and `ufw`, so that those tools co-exist and do not
// negatively affect Tailscale function.
type nftablesRunner struct {
conn *nftables.Conn // netlink connection used for every nftables operation
nft4 *nftable // IPv4 tables; always set by newNfTablesRunner
nft6 *nftable // IPv6 tables; nil when newNfTablesRunner finds IPv6 unavailable
v6Available bool // reported by HasIPV6
v6NATAvailable bool // reported by HasIPV6NAT
}
// ensurePreroutingChain makes sure that the "nat" table and its "PREROUTING"
// chain exist for the address family of dst, creating either one if it is
// missing, and returns both.
//
// An error is returned (rather than panicking on a nil dereference) when the
// runner has no nftable for dst's family, which is the case for IPv6
// addresses on a host where IPv6 was detected as unavailable.
func (n *nftablesRunner) ensurePreroutingChain(dst netip.Addr) (*nftables.Table, *nftables.Chain, error) {
	table := n.getNFTByAddr(dst)
	if table == nil {
		return nil, nil, fmt.Errorf("no nftables table available for address %v", dst)
	}
	nat, err := createTableIfNotExist(n.conn, table.Proto, "nat")
	if err != nil {
		return nil, nil, fmt.Errorf("error ensuring nat table: %w", err)
	}

	// ensure prerouting chain exists
	polAccept := nftables.ChainPolicyAccept
	preroutingCh, err := getOrCreateChain(n.conn, chainInfo{
		table:         nat,
		name:          "PREROUTING",
		chainType:     nftables.ChainTypeNAT,
		chainHook:     nftables.ChainHookPrerouting,
		chainPriority: nftables.ChainPriorityNATDest,
		chainPolicy:   &polAccept,
	})
	if err != nil {
		return nil, nil, fmt.Errorf("error ensuring prerouting chain: %w", err)
	}
	return nat, preroutingCh, nil
}
// AddDNATRule inserts a rule at the top of the nat/PREROUTING chain that
// rewrites the destination of packets addressed to origDst so that they go to
// dst instead. The table is chosen by dst's address family; the header offset
// and NAT family are chosen by origDst's address family.
func (n *nftablesRunner) AddDNATRule(origDst netip.Addr, dst netip.Addr) error {
	nat, preroutingCh, err := n.ensurePreroutingChain(dst)
	if err != nil {
		return err
	}

	// Location of the destination address inside the IP header, and the NAT
	// family. Start with the IPv6 values and override them for IPv4.
	offset, length, family := uint32(24), uint32(16), uint32(unix.NFPROTO_IPV6)
	if origDst.Is4() {
		offset, length, family = 16, 4, unix.NFPROTO_IPV4
	}

	// Load the packet's destination address into register 1.
	loadDaddr := &expr.Payload{
		DestRegister: 1,
		Base:         expr.PayloadBaseNetworkHeader,
		Offset:       offset,
		Len:          length,
	}
	// Continue only if it equals origDst.
	isOrigDst := &expr.Cmp{
		Op:       expr.CmpOpEq,
		Register: 1,
		Data:     origDst.AsSlice(),
	}
	// Put the new destination into register 1 and DNAT to it.
	loadNewDst := &expr.Immediate{
		Register: 1,
		Data:     dst.AsSlice(),
	}
	dnat := &expr.NAT{
		Type:       expr.NATTypeDestNAT,
		Family:     family,
		RegAddrMin: 1,
	}

	n.conn.InsertRule(&nftables.Rule{
		Table: nat,
		Chain: preroutingCh,
		Exprs: []expr.Any{loadDaddr, isOrigDst, loadNewDst, dnat},
	})
	return n.conn.Flush()
}
// DNATNonTailscaleTraffic appends a rule to the nat/PREROUTING chain that
// DNATs to dst every packet that did NOT arrive on the interface named
// tunname.
//
// The match is on the input interface name: the rule is about where the
// packet came in from, and it lives in a prerouting chain, where no output
// interface has been selected yet.
func (n *nftablesRunner) DNATNonTailscaleTraffic(tunname string, dst netip.Addr) error {
	nat, preroutingCh, err := n.ensurePreroutingChain(dst)
	if err != nil {
		return err
	}
	var famConst uint32
	if dst.Is4() {
		famConst = unix.NFPROTO_IPV4
	} else {
		famConst = unix.NFPROTO_IPV6
	}
	dnatRule := &nftables.Rule{
		Table: nat,
		Chain: preroutingCh,
		Exprs: []expr.Any{
			// Load the name of the interface the packet arrived on.
			&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
			// Continue only for packets that did not arrive on tunname.
			&expr.Cmp{
				Op:       expr.CmpOpNeq,
				Register: 1,
				Data:     []byte(tunname),
			},
			// Load the new destination and DNAT to it.
			&expr.Immediate{
				Register: 1,
				Data:     dst.AsSlice(),
			},
			&expr.NAT{
				Type:       expr.NATTypeDestNAT,
				Family:     famConst,
				RegAddrMin: 1,
			},
		},
	}
	n.conn.AddRule(dnatRule)
	return n.conn.Flush()
}
// AddSNATRuleForDst appends a rule to the nat/POSTROUTING chain that rewrites
// the source address of packets destined for dst to src. The "nat" table and
// the "POSTROUTING" chain are created first if they do not exist.
//
// An error is returned (rather than panicking on a nil dereference) when the
// runner has no nftable for dst's address family, which is the case for IPv6
// addresses on a host where IPv6 was detected as unavailable.
func (n *nftablesRunner) AddSNATRuleForDst(src, dst netip.Addr) error {
	table := n.getNFTByAddr(dst)
	if table == nil {
		return fmt.Errorf("no nftables table available for address %v", dst)
	}
	nat, err := createTableIfNotExist(n.conn, table.Proto, "nat")
	if err != nil {
		return fmt.Errorf("error ensuring nat table exists: %w", err)
	}

	// ensure postrouting chain exists
	polAccept := nftables.ChainPolicyAccept
	postRoutingCh, err := getOrCreateChain(n.conn, chainInfo{
		table:         nat,
		name:          "POSTROUTING",
		chainType:     nftables.ChainTypeNAT,
		chainHook:     nftables.ChainHookPostrouting,
		chainPriority: nftables.ChainPriorityNATSource,
		chainPolicy:   &polAccept,
	})
	if err != nil {
		return fmt.Errorf("error ensuring postrouting chain: %w", err)
	}

	// Offset and length of the destination address in the IP header, and the
	// NAT family, for dst's address family.
	var daddrOffset, fam, daddrLen uint32
	if dst.Is4() {
		daddrOffset = 16
		daddrLen = 4
		fam = unix.NFPROTO_IPV4
	} else {
		daddrOffset = 24
		daddrLen = 16
		fam = unix.NFPROTO_IPV6
	}

	snatRule := &nftables.Rule{
		Table: nat,
		Chain: postRoutingCh,
		Exprs: []expr.Any{
			// Load the packet's destination address and require it to be dst.
			&expr.Payload{
				DestRegister: 1,
				Base:         expr.PayloadBaseNetworkHeader,
				Offset:       daddrOffset,
				Len:          daddrLen,
			},
			&expr.Cmp{
				Op:       expr.CmpOpEq,
				Register: 1,
				Data:     dst.AsSlice(),
			},
			// Load the new source address and SNAT to it.
			&expr.Immediate{
				Register: 1,
				Data:     src.AsSlice(),
			},
			&expr.NAT{
				Type:       expr.NATTypeSourceNAT,
				Family:     fam,
				RegAddrMin: 1,
			},
		},
	}
	n.conn.AddRule(snatRule)
	return n.conn.Flush()
}
// ClampMSSToPMTU appends a rule to the filter/FORWARD chain (for the address
// family of addr) that, for TCP packets leaving via the interface named tun
// with the SYN flag set, overwrites the TCP MSS option with the MSS taken
// from the route. The "filter" table and "FORWARD" chain are created first if
// they do not exist.
//
// An error is returned (rather than panicking on a nil dereference) when the
// runner has no nftable for addr's address family, which is the case for IPv6
// addresses on a host where IPv6 was detected as unavailable.
func (n *nftablesRunner) ClampMSSToPMTU(tun string, addr netip.Addr) error {
	table := n.getNFTByAddr(addr)
	if table == nil {
		return fmt.Errorf("no nftables table available for address %v", addr)
	}
	filterTable, err := createTableIfNotExist(n.conn, table.Proto, "filter")
	if err != nil {
		return fmt.Errorf("error ensuring filter table: %w", err)
	}

	// ensure forwarding chain exists
	polAccept := nftables.ChainPolicyAccept
	fwChain, err := getOrCreateChain(n.conn, chainInfo{
		table:         filterTable,
		name:          "FORWARD",
		chainType:     nftables.ChainTypeFilter,
		chainHook:     nftables.ChainHookForward,
		chainPriority: nftables.ChainPriorityFilter,
		chainPolicy:   &polAccept,
	})
	if err != nil {
		return fmt.Errorf("error ensuring forward chain: %w", err)
	}

	clampRule := &nftables.Rule{
		Table: filterTable,
		Chain: fwChain,
		Exprs: []expr.Any{
			// Only packets going out via tun.
			&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
			&expr.Cmp{
				Op:       expr.CmpOpEq,
				Register: 1,
				Data:     []byte(tun),
			},
			// Only TCP.
			&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
			&expr.Cmp{
				Op:       expr.CmpOpEq,
				Register: 1,
				Data:     []byte{unix.IPPROTO_TCP},
			},
			// Load the TCP flags byte (transport header offset 13), keep only
			// the SYN bit (0x02) and require it to be set.
			&expr.Payload{
				DestRegister: 1,
				Base:         expr.PayloadBaseTransportHeader,
				Offset:       13,
				Len:          1,
			},
			&expr.Bitwise{
				DestRegister:   1,
				SourceRegister: 1,
				Len:            1,
				Mask:           []byte{0x02},
				Xor:            []byte{0x00},
			},
			&expr.Cmp{
				Op:       expr.CmpOpNeq,
				Register: 1,
				Data:     []byte{0x00},
			},
			// Load the route's TCP MSS and convert it to network byte order.
			&expr.Rt{
				Register: 1,
				Key:      expr.RtTCPMSS,
			},
			&expr.Byteorder{
				DestRegister:   1,
				SourceRegister: 1,
				Op:             expr.ByteorderHton,
				Len:            2,
				Size:           2,
			},
			// Write the 2-byte value at offset 2 of TCP option type 2 (MSS).
			&expr.Exthdr{
				SourceRegister: 1,
				Type:           2,
				Offset:         2,
				Len:            2,
				Op:             expr.ExthdrOpTcpopt,
			},
		},
	}
	n.conn.AddRule(clampRule)
	return n.conn.Flush()
}
// deleteTableIfExists deletes a nftables table via connection c if it exists
// within the given family.
//
// If the delete fails but a follow-up lookup shows that the table is gone
// (for example because a concurrent process removed it), the failure is
// ignored. Otherwise the error from the failed delete is returned; it is kept
// in its own variable so that the follow-up lookup cannot overwrite it.
func deleteTableIfExists(c *nftables.Conn, family nftables.TableFamily, name string) error {
	t, err := getTableIfExists(c, family, name)
	if err != nil {
		return fmt.Errorf("get table: %w", err)
	}
	if t == nil {
		// Table does not exist, so nothing to delete.
		return nil
	}

	c.DelTable(t)
	flushErr := c.Flush()
	if flushErr == nil {
		return nil
	}

	// Check if the table still exists. If it does not, then the error is due
	// to the table not existing, so we can ignore it.
	if remaining, lookupErr := getTableIfExists(c, family, name); lookupErr == nil && remaining == nil {
		return nil
	}
	return fmt.Errorf("del table: %w", flushErr)
}
// getTableIfExists looks up the table called name in the given family.
// It returns (nil, nil) when there is no such table, and a non-nil error only
// when the tables could not be listed.
func getTableIfExists(c *nftables.Conn, family nftables.TableFamily, name string) (*nftables.Table, error) {
	all, err := c.ListTables()
	if err != nil {
		return nil, fmt.Errorf("get tables: %w", err)
	}
	for i := range all {
		candidate := all[i]
		if candidate.Name != name || candidate.Family != family {
			continue
		}
		return candidate, nil
	}
	return nil, nil
}
// createTableIfNotExist returns the table called name in the given family,
// creating it via connection c (and flushing) when it does not exist yet.
func createTableIfNotExist(c *nftables.Conn, family nftables.TableFamily, name string) (*nftables.Table, error) {
	existing, err := getTableIfExists(c, family, name)
	if err != nil {
		return nil, fmt.Errorf("get table: %w", err)
	}
	if existing != nil {
		return existing, nil
	}

	created := c.AddTable(&nftables.Table{Name: name, Family: family})
	if err := c.Flush(); err != nil {
		return nil, fmt.Errorf("add table: %w", err)
	}
	return created, nil
}
// errorChainNotFound is the error returned when a chain lookup finds no chain
// with the requested name in the requested table.
//
// The fields are declared table first, chain second, because the call sites
// build this value with a positional literal of the form
// errorChainNotFound{table.Name, name}. With the opposite declaration order
// those literals put the table name in chainName and vice versa, producing a
// message with the two names swapped.
type errorChainNotFound struct {
	tableName string
	chainName string
}

// Error implements the error interface.
func (e errorChainNotFound) Error() string {
	return fmt.Sprintf("chain %s not found in table %s", e.chainName, e.tableName)
}
// getChainFromTable looks up the chain called name inside table. Chain names
// are unique within a table, so at most one chain can match.
//
// When no chain matches, the returned error is an errorChainNotFound built
// positionally from (table.Name, name); getOrCreateChain and
// deleteChainIfExists detect the "not found" case by comparing against a
// literal built the same way, so the argument order must stay in sync.
func getChainFromTable(c *nftables.Conn, table *nftables.Table, name string) (*nftables.Chain, error) {
	all, err := c.ListChainsOfTableFamily(table.Family)
	if err != nil {
		return nil, fmt.Errorf("list chains: %w", err)
	}
	for i := range all {
		candidate := all[i]
		// The listing is already restricted to the table's family, so the
		// table name alone identifies the table.
		if candidate.Table.Name != table.Name || candidate.Name != name {
			continue
		}
		return candidate, nil
	}
	return nil, errorChainNotFound{table.Name, name}
}
// isTSChain reports whether name carries the "ts-" prefix that marks a
// Tailscale-managed chain.
func isTSChain(name string) bool {
	_, managed := strings.CutPrefix(name, "ts-")
	return managed
}
// createChainIfNotExist makes sure the chain described by cinfo exists in its
// table. It is getOrCreateChain for callers that do not need the chain value.
func createChainIfNotExist(c *nftables.Conn, cinfo chainInfo) error {
	if _, err := getOrCreateChain(c, cinfo); err != nil {
		return err
	}
	return nil
}
// getOrCreateChain returns the chain described by cinfo, creating it (and
// flushing the connection) if it does not exist yet.
//
// For an existing Tailscale-managed chain (see isTSChain) the type, hook and
// priority must match cinfo, otherwise an error is returned. Existing
// "conventional" chains are returned as they are, without that check.
//
// The chain returned for the "created" case is the value handed back by
// AddChain, so callers always get a non-nil chain on success.
func getOrCreateChain(c *nftables.Conn, cinfo chainInfo) (*nftables.Chain, error) {
	chain, err := getChainFromTable(c, cinfo.table, cinfo.name)
	if err == nil {
		// The chain already exists. If it is a TS chain, check the
		// type/hook/priority, but for "conventional chains" assume they're what
		// we expect (in case iptables-nft/ufw make minor behavior changes in
		// the future).
		if isTSChain(chain.Name) && !chainMatchesInfo(chain, cinfo) {
			return nil, fmt.Errorf("chain %s already exists with different type/hook/priority", cinfo.name)
		}
		return chain, nil
	}
	if !errors.Is(err, errorChainNotFound{cinfo.table.Name, cinfo.name}) {
		return nil, fmt.Errorf("get chain: %w", err)
	}

	chain = c.AddChain(&nftables.Chain{
		Name:     cinfo.name,
		Table:    cinfo.table,
		Type:     cinfo.chainType,
		Hooknum:  cinfo.chainHook,
		Priority: cinfo.chainPriority,
		Policy:   cinfo.chainPolicy,
	})
	if err := c.Flush(); err != nil {
		return nil, fmt.Errorf("add chain: %w", err)
	}
	return chain, nil
}

// chainMatchesInfo reports whether chain has the type, hook and priority that
// cinfo asks for. Hook and priority are compared by value, not by pointer
// identity: a chain read back from the kernel never shares pointers with the
// values in cinfo.
func chainMatchesInfo(chain *nftables.Chain, cinfo chainInfo) bool {
	return chain.Type == cinfo.chainType &&
		samePointee(chain.Hooknum, cinfo.chainHook) &&
		samePointee(chain.Priority, cinfo.chainPriority)
}

// samePointee reports whether a and b are both nil, or both non-nil and
// pointing at equal values.
func samePointee[T comparable](a, b *T) bool {
	if a == nil || b == nil {
		return a == b
	}
	return *a == *b
}
// NetfilterRunner abstracts helpers to run netfilter commands. New returns
// either the iptables-based or the nftables-based implementation.
type NetfilterRunner interface {
// AddLoopbackRule adds a rule to permit loopback traffic to addr. This rule
// is added only if it does not already exist.
AddLoopbackRule(addr netip.Addr) error
// DelLoopbackRule removes the rule added by AddLoopbackRule.
DelLoopbackRule(addr netip.Addr) error
// AddHooks adds rules to conventional chains like "FORWARD", "INPUT" and
// "POSTROUTING" to jump from those chains to tailscale chains.
AddHooks() error
// DelHooks deletes rules added by AddHooks.
DelHooks(logf logger.Logf) error
// AddChains creates custom Tailscale chains.
AddChains() error
// DelChains removes chains added by AddChains.
DelChains() error
// AddBase adds rules reused by different other rules.
AddBase(tunname string) error
// DelBase removes rules added by AddBase.
DelBase() error
// AddSNATRule adds the netfilter rule to SNAT incoming traffic over
// the Tailscale interface destined for local subnets. An error is
// returned if the rule already exists.
AddSNATRule() error
// DelSNATRule removes the rule added by AddSNATRule.
DelSNATRule() error
// HasIPV6 reports true if the system supports IPv6.
HasIPV6() bool
// HasIPV6NAT reports true if the system supports IPv6 NAT.
HasIPV6NAT() bool
// AddDNATRule adds a rule to the nat/PREROUTING chain to DNAT traffic
// destined for the given original destination to the given new destination.
// This is used to forward all traffic destined for the Tailscale interface
// to the provided destination, as used in the Kubernetes ingress proxies.
AddDNATRule(origDst, dst netip.Addr) error
// AddSNATRuleForDst adds a rule to the nat/POSTROUTING chain to SNAT
// traffic destined for dst to src.
// This is used to forward traffic destined for the local machine over
// the Tailscale interface, as used in the Kubernetes egress proxies.
AddSNATRuleForDst(src, dst netip.Addr) error
// DNATNonTailscaleTraffic adds a rule to the nat/PREROUTING chain to DNAT
// all traffic inbound from any interface except exemptInterface to dst.
// This is used to forward traffic destined for the local machine over
// the Tailscale interface, as used in the Kubernetes egress proxies.
DNATNonTailscaleTraffic(exemptInterface string, dst netip.Addr) error
// ClampMSSToPMTU adds a rule to the mangle/FORWARD chain to clamp MSS for
// traffic destined for the provided tun interface.
// NOTE(review): the nftables implementation in this file installs the rule
// in filter/FORWARD; "mangle" presumably describes the iptables
// implementation — confirm.
ClampMSSToPMTU(tun string, addr netip.Addr) error
// AddMagicsockPortRule adds a rule to the ts-input chain to accept
// incoming traffic on the specified port, to allow magicsock to
// communicate.
AddMagicsockPortRule(port uint16, network string) error
// DelMagicsockPortRule removes the rule created by AddMagicsockPortRule,
// if it exists.
DelMagicsockPortRule(port uint16, network string) error
}
// New returns a NetfilterRunner backed by either iptables or nftables,
// depending on what detectFirewallMode picks for the given prefHint.
// As nftables is still experimental, iptables will be used unless
// either the TS_DEBUG_FIREWALL_MODE environment variable, or the prefHint
// parameter, is set to one of "nftables" or "auto".
// An unrecognised mode yields an error.
func New(logf logger.Logf, prefHint string) (NetfilterRunner, error) {
	mode := detectFirewallMode(logf, prefHint)
	if mode == FirewallModeIPTables {
		return newIPTablesRunner(logf)
	}
	if mode == FirewallModeNfTables {
		return newNfTablesRunner(logf)
	}
	return nil, fmt.Errorf("unknown firewall mode %v", mode)
}
// newNfTablesRunner opens an nftables connection and returns a runner for it.
// It does not create any tables or chains. The IPv6 side (nft6 and the two
// availability flags) is filled in only when checkIPv6 reports no error.
func newNfTablesRunner(logf logger.Logf) (*nftablesRunner, error) {
	conn, err := nftables.New()
	if err != nil {
		return nil, fmt.Errorf("nftables connection: %w", err)
	}

	// TODO(KevinLiang10): convert iptables rule to nftable rules if they exist in the iptables
	runner := &nftablesRunner{
		conn: conn,
		nft4: &nftable{Proto: nftables.TableFamilyIPv4},
	}

	if v6err := checkIPv6(logf); v6err != nil {
		logf("disabling tunneled IPv6 due to system IPv6 config: %v", v6err)
		return runner, nil
	}

	runner.v6Available = true
	runner.v6NATAvailable = checkSupportsV6NAT()
	logf("v6nat availability: %v", runner.v6NATAvailable)
	runner.nft6 = &nftable{Proto: nftables.TableFamilyIPv6}
	return runner, nil
}
// newLoadSaddrExpr builds the nftables expression that copies a packet's
// source address from the network header into register destReg. It fails for
// any table family other than IPv4 and IPv6.
func newLoadSaddrExpr(proto nftables.TableFamily, destReg uint32) (expr.Any, error) {
	// Position of the source address within the IP header.
	var offset, length uint32
	switch proto {
	case nftables.TableFamilyIPv4:
		offset, length = 12, 4
	case nftables.TableFamilyIPv6:
		offset, length = 8, 16
	default:
		return nil, fmt.Errorf("table family %v is neither IPv4 nor IPv6", proto)
	}

	return &expr.Payload{
		DestRegister: destReg,
		Base:         expr.PayloadBaseNetworkHeader,
		Offset:       offset,
		Len:          length,
	}, nil
}
// newLoadDportExpr builds the nftables expression that copies the destination
// port of a TCP/UDP packet (2 bytes at offset 2 of the transport header) into
// register destReg.
func newLoadDportExpr(destReg uint32) expr.Any {
	dport := expr.Payload{
		Base:   expr.PayloadBaseTransportHeader,
		Offset: 2,
		Len:    2,
	}
	dport.DestRegister = destReg
	return &dport
}
// HasIPV6 reports true if the system supports IPv6. The value is the
// v6Available flag stored on the runner when it was constructed.
func (n *nftablesRunner) HasIPV6() bool {
return n.v6Available
}
// HasIPV6NAT reports true if the system supports IPv6 NAT. The value is the
// v6NATAvailable flag stored on the runner when it was constructed.
func (n *nftablesRunner) HasIPV6NAT() bool {
return n.v6NATAvailable
}
// findRule searches the chain that rule belongs to for an installed rule with
// the same expressions and returns it, or (nil, nil) if there is none.
//
// Two rules match when they have the same number of expressions and each pair
// of expressions at the same index is deeply equal. Counter expressions are
// the exception: their contents are not compared (an installed counter will
// not equal a freshly built one), but a counter only matches another counter
// at the same position.
func findRule(conn *nftables.Conn, rule *nftables.Rule) (*nftables.Rule, error) {
	rules, err := conn.GetRules(rule.Table, rule.Chain)
	if err != nil {
		return nil, fmt.Errorf("get nftables rules: %w", err)
	}

ruleLoop:
	for _, r := range rules {
		if len(r.Exprs) != len(rule.Exprs) {
			continue
		}
		for i, e := range r.Exprs {
			if _, ok := e.(*expr.Counter); ok {
				// Skip the contents of counter expressions, as they will
				// not match, but insist that the wanted rule has a counter
				// here too.
				if _, wantCounter := rule.Exprs[i].(*expr.Counter); !wantCounter {
					continue ruleLoop
				}
				continue
			}
			if !reflect.DeepEqual(e, rule.Exprs[i]) {
				continue ruleLoop
			}
		}
		return r, nil
	}
	return nil, nil
}
// createLoopbackRule builds (but does not install) the rule that accepts
// packets arriving on the "lo" interface whose source address is addr. The
// rule also carries a counter. proto selects where the source address is read
// from in the IP header.
func createLoopbackRule(
	proto nftables.TableFamily,
	table *nftables.Table,
	chain *nftables.Chain,
	addr netip.Addr,
) (*nftables.Rule, error) {
	loadSaddr, err := newLoadSaddrExpr(proto, 1)
	if err != nil {
		return nil, fmt.Errorf("newLoadSaddrExpr: %w", err)
	}

	exprs := []expr.Any{
		// Input interface must be "lo".
		&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
		&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte("lo")},
		// Source address must be addr.
		loadSaddr,
		&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: addr.AsSlice()},
		// Count and accept.
		&expr.Counter{},
		&expr.Verdict{Kind: expr.VerdictAccept},
	}
	return &nftables.Rule{Table: table, Chain: chain, Exprs: exprs}, nil
}
// insertLoopbackRule puts the Tailscale loopback rule for addr at the top of
// the given chain, unless an equal rule is already installed there.
func insertLoopbackRule(
	conn *nftables.Conn, proto nftables.TableFamily,
	table *nftables.Table, chain *nftables.Chain, addr netip.Addr) error {

	wanted, err := createLoopbackRule(proto, table, chain, addr)
	if err != nil {
		return fmt.Errorf("create loopback rule: %w", err)
	}

	// If TestDial is set, we are running in test mode and we should not
	// find rule because header will mismatch.
	if conn.TestDial == nil {
		existing, err := findRule(conn, wanted)
		switch {
		case err != nil:
			return fmt.Errorf("find rule: %w", err)
		case existing != nil:
			// Already installed; nothing to do.
			return nil
		}
	}

	// InsertRule places the rule at the top of the chain.
	conn.InsertRule(wanted)
	if err := conn.Flush(); err != nil {
		return fmt.Errorf("insert rule: %w", err)
	}
	return nil
}
// getNFTByAddr picks the per-family nftable matching addr: nft6 when
// addr.Is6() is true, nft4 otherwise. The result can be nil, since nft6 is
// only set when IPv6 is available.
func (n *nftablesRunner) getNFTByAddr(addr netip.Addr) *nftable {
	if !addr.Is6() {
		return n.nft4
	}
	return n.nft6
}
// AddLoopbackRule adds an nftables rule to permit loopback traffic to
// a local Tailscale IP. This rule is added only if it does not already exist.
// The rule goes into the ts-input chain of the filter table for addr's
// address family.
//
// An error is returned (rather than panicking on a nil dereference) when the
// runner has no nftable for addr's address family, which is the case for IPv6
// addresses on a host where IPv6 was detected as unavailable.
func (n *nftablesRunner) AddLoopbackRule(addr netip.Addr) error {
	nf := n.getNFTByAddr(addr)
	if nf == nil {
		return fmt.Errorf("no nftables table available for address %v", addr)
	}

	inputChain, err := getChainFromTable(n.conn, nf.Filter, chainNameInput)
	if err != nil {
		return fmt.Errorf("get input chain: %w", err)
	}

	if err := insertLoopbackRule(n.conn, nf.Proto, nf.Filter, inputChain, addr); err != nil {
		return fmt.Errorf("add loopback rule: %w", err)
	}
	return nil
}
// DelLoopbackRule removes the nftables rule permitting loopback
// traffic to a Tailscale IP. It is a no-op when no matching rule is
// installed in the ts-input chain.
//
// An error is returned (rather than panicking on a nil dereference) when the
// runner has no nftable for addr's address family, which is the case for IPv6
// addresses on a host where IPv6 was detected as unavailable.
func (n *nftablesRunner) DelLoopbackRule(addr netip.Addr) error {
	nf := n.getNFTByAddr(addr)
	if nf == nil {
		return fmt.Errorf("no nftables table available for address %v", addr)
	}

	inputChain, err := getChainFromTable(n.conn, nf.Filter, chainNameInput)
	if err != nil {
		return fmt.Errorf("get input chain: %w", err)
	}

	// Build the rule we expect to find, then look for its installed twin.
	loopBackRule, err := createLoopbackRule(nf.Proto, nf.Filter, inputChain, addr)
	if err != nil {
		return fmt.Errorf("create loopback rule: %w", err)
	}
	existingLoopBackRule, err := findRule(n.conn, loopBackRule)
	if err != nil {
		return fmt.Errorf("find loop back rule: %w", err)
	}
	if existingLoopBackRule == nil {
		// Rule does not exist, no need to delete.
		return nil
	}

	if err := n.conn.DelRule(existingLoopBackRule); err != nil {
		return fmt.Errorf("delete rule: %w", err)
	}
	return n.conn.Flush()
}
// getTables lists the per-family nftables the runner works with: always the
// IPv4 one, followed by the IPv6 one when IPv6 is available.
func (n *nftablesRunner) getTables() []*nftable {
	tables := []*nftable{n.nft4}
	if n.v6Available {
		tables = append(tables, n.nft6)
	}
	return tables
}
// getNATTables lists the per-family nftables to use for NAT rules. Without
// IPv6 NAT support only the IPv4 nftable is returned; with it, the result is
// whatever getTables returns.
func (n *nftablesRunner) getNATTables() []*nftable {
	if !n.v6NATAvailable {
		return []*nftable{n.nft4}
	}
	return n.getTables()
}
// AddChains makes sure that the conventional tables and chains, and the
// Tailscale chains inside them, exist:
//   - per available family: table "filter" with chains FORWARD, INPUT,
//     ts-forward and ts-input;
//   - per NAT-capable family: table "nat" with chains POSTROUTING and
//     ts-postrouting.
//
// The table names are the ones used by iptables-nft and ufw. We install rules
// into the same conventional tables so that `accept` verdicts from our jump
// chains are conclusive. The tables found or created are recorded in the
// runner's nftable values.
func (n *nftablesRunner) AddChains() error {
	acceptPolicy := nftables.ChainPolicyAccept

	for _, family := range n.getTables() {
		filterTable, err := createTableIfNotExist(n.conn, family.Proto, "filter")
		if err != nil {
			return fmt.Errorf("create table: %w", err)
		}
		family.Filter = filterTable

		// The "conventional chains" that are used by iptables-nft and ufw.
		hookedForward := chainInfo{
			table:         filterTable,
			name:          "FORWARD",
			chainType:     nftables.ChainTypeFilter,
			chainHook:     nftables.ChainHookForward,
			chainPriority: nftables.ChainPriorityFilter,
			chainPolicy:   &acceptPolicy,
		}
		if err := createChainIfNotExist(n.conn, hookedForward); err != nil {
			return fmt.Errorf("create forward chain: %w", err)
		}
		hookedInput := chainInfo{
			table:         filterTable,
			name:          "INPUT",
			chainType:     nftables.ChainTypeFilter,
			chainHook:     nftables.ChainHookInput,
			chainPriority: nftables.ChainPriorityFilter,
			chainPolicy:   &acceptPolicy,
		}
		if err := createChainIfNotExist(n.conn, hookedInput); err != nil {
			return fmt.Errorf("create input chain: %w", err)
		}

		// The tailscale chains that contain our rules.
		tsForward := chainInfo{table: filterTable, name: chainNameForward, chainType: chainTypeRegular}
		if err := createChainIfNotExist(n.conn, tsForward); err != nil {
			return fmt.Errorf("create forward chain: %w", err)
		}
		tsInput := chainInfo{table: filterTable, name: chainNameInput, chainType: chainTypeRegular}
		if err := createChainIfNotExist(n.conn, tsInput); err != nil {
			return fmt.Errorf("create input chain: %w", err)
		}
	}

	for _, family := range n.getNATTables() {
		natTable, err := createTableIfNotExist(n.conn, family.Proto, "nat")
		if err != nil {
			return fmt.Errorf("create table: %w", err)
		}
		family.Nat = natTable

		// The "conventional chain" that is used by iptables-nft and ufw.
		hookedPostrouting := chainInfo{
			table:         natTable,
			name:          "POSTROUTING",
			chainType:     nftables.ChainTypeNAT,
			chainHook:     nftables.ChainHookPostrouting,
			chainPriority: nftables.ChainPriorityNATSource,
			chainPolicy:   &acceptPolicy,
		}
		if err := createChainIfNotExist(n.conn, hookedPostrouting); err != nil {
			return fmt.Errorf("create postrouting chain: %w", err)
		}

		// The tailscale chain that contains our rules.
		tsPostrouting := chainInfo{table: natTable, name: chainNamePostrouting, chainType: chainTypeRegular}
		if err := createChainIfNotExist(n.conn, tsPostrouting); err != nil {
			return fmt.Errorf("create postrouting chain: %w", err)
		}
	}

	return n.conn.Flush()
}
// These are dummy chains and tables we create to detect if nftables is
// available. We create them, then delete them. If we can create and delete
// them, then we can use nftables. If we can't, then we assume that we're
// running on a system that doesn't support nftables. See
// createDummyPostroutingChains.
const (
tsDummyChainName = "ts-test-postrouting" // chain created inside the dummy table
tsDummyTableName = "ts-test-nat" // throwaway table, one per address family
)
// createDummyPostroutingChains creates dummy postrouting chains in netfilter
// via nftables, as a last resort measure to detect that nftables can be used.
// For every NAT-capable family it creates the dummy table, creates the dummy
// chain in it, and deletes the chain again. The dummy tables are deleted when
// the function returns.
//
// The first error encountered is returned. An error from the deferred table
// cleanup is reported only if nothing else failed.
//
// Note that the Nat field of each visited nftable is left pointing at the
// dummy table.
func (n *nftablesRunner) createDummyPostroutingChains() (retErr error) {
	polAccept := ptr.To(nftables.ChainPolicyAccept)
	for _, table := range n.getNATTables() {
		nat, err := createTableIfNotExist(n.conn, table.Proto, tsDummyTableName)
		if err != nil {
			return fmt.Errorf("create nat table: %w", err)
		}
		// The family is passed in as an argument so that each deferred call
		// deletes the table of its own iteration, independent of how the loop
		// variable is scoped.
		defer func(fm nftables.TableFamily) {
			if err := deleteTableIfExists(n.conn, fm, tsDummyTableName); err != nil && retErr == nil {
				retErr = fmt.Errorf("delete %q table: %w", tsDummyTableName, err)
			}
		}(table.Proto)

		table.Nat = nat
		if err = createChainIfNotExist(n.conn, chainInfo{nat, tsDummyChainName, nftables.ChainTypeNAT, nftables.ChainHookPostrouting, nftables.ChainPriorityNATSource, polAccept}); err != nil {
			return fmt.Errorf("create %q chain: %w", tsDummyChainName, err)
		}
		if err := deleteChainIfExists(n.conn, nat, tsDummyChainName); err != nil {
			return fmt.Errorf("delete %q chain: %w", tsDummyChainName, err)
		}
	}
	return nil
}
// deleteChainIfExists removes the chain called name from table: the chain's
// rules are flushed, the chain is deleted, and the connection is flushed.
// A chain that does not exist is not an error.
func deleteChainIfExists(c *nftables.Conn, table *nftables.Table, name string) error {
	chain, err := getChainFromTable(c, table, name)
	switch {
	case err == nil:
		// Found; delete it below.
	case errors.Is(err, errorChainNotFound{table.Name, name}):
		// If the chain doesn't exist, we don't need to delete it.
		return nil
	default:
		return fmt.Errorf("get chain: %w", err)
	}

	c.FlushChain(chain)
	c.DelChain(chain)
	if err := c.Flush(); err != nil {
		return fmt.Errorf("flush and delete chain: %w", err)
	}
	return nil
}
// DelChains deletes the Tailscale chains: ts-forward and ts-input from the
// filter table of every available family, then ts-postrouting from the IPv4
// nat table and, when IPv6 NAT is available, from the IPv6 nat table.
// Chains that are already gone are skipped.
func (n *nftablesRunner) DelChains() error {
	for _, family := range n.getTables() {
		for _, chainName := range []string{chainNameForward, chainNameInput} {
			if err := deleteChainIfExists(n.conn, family.Filter, chainName); err != nil {
				return fmt.Errorf("delete chain: %w", err)
			}
		}
	}

	natFamilies := []*nftable{n.nft4}
	if n.v6NATAvailable {
		natFamilies = append(natFamilies, n.nft6)
	}
	for _, family := range natFamilies {
		if err := deleteChainIfExists(n.conn, family.Nat, chainNamePostrouting); err != nil {
			return fmt.Errorf("delete chain: %w", err)
		}
	}

	if err := n.conn.Flush(); err != nil {
		return fmt.Errorf("flush: %w", err)
	}
	return nil
}
// createHookRule builds (but does not install) a rule for fromChain that
// counts every packet and jumps to the chain named toChainName.
func createHookRule(table *nftables.Table, fromChain *nftables.Chain, toChainName string) *nftables.Rule {
	jump := &expr.Verdict{Kind: expr.VerdictJump, Chain: toChainName}
	return &nftables.Rule{
		Table: table,
		Chain: fromChain,
		Exprs: []expr.Any{&expr.Counter{}, jump},
	}
}
// addHookRule installs, at the top of fromChain, a rule that jumps to the
// chain named toChainName, and flushes the connection.
func addHookRule(conn *nftables.Conn, table *nftables.Table, fromChain *nftables.Chain, toChainName string) error {
	conn.InsertRule(createHookRule(table, fromChain, toChainName))
	err := conn.Flush()
	if err == nil {
		return nil
	}
	return fmt.Errorf("flush add rule: %w", err)
}
// AddHooks adds rules to the conventional chains "INPUT", "FORWARD" (filter
// table, every available family) and "POSTROUTING" (nat table, every
// NAT-capable family) that jump from those chains to the corresponding
// tailscale chains. The conventional chains must already exist.
func (n *nftablesRunner) AddHooks() error {
	conn := n.conn

	for _, table := range n.getTables() {
		inputChain, err := getChainFromTable(conn, table.Filter, "INPUT")
		if err != nil {
			return fmt.Errorf("get INPUT chain: %w", err)
		}
		err = addHookRule(conn, table.Filter, inputChain, chainNameInput)
		if err != nil {
			return fmt.Errorf("Addhook: %w", err)
		}

		forwardChain, err := getChainFromTable(conn, table.Filter, "FORWARD")
		if err != nil {
			return fmt.Errorf("get FORWARD chain: %w", err)
		}
		err = addHookRule(conn, table.Filter, forwardChain, chainNameForward)
		if err != nil {
			return fmt.Errorf("Addhook: %w", err)
		}
	}

	for _, table := range n.getNATTables() {
		postroutingChain, err := getChainFromTable(conn, table.Nat, "POSTROUTING")
		if err != nil {
			// Name the chain that was actually looked up.
			return fmt.Errorf("get POSTROUTING chain: %w", err)
		}
		err = addHookRule(conn, table.Nat, postroutingChain, chainNamePostrouting)
		if err != nil {
			return fmt.Errorf("Addhook: %w", err)
		}
	}
	return nil
}
// delHookRule removes from fromChain the installed rule that jumps to the
// chain named toChainName. It does nothing when no such rule is installed.
func delHookRule(conn *nftables.Conn, table *nftables.Table, fromChain *nftables.Chain, toChainName string) error {
	installed, err := findRule(conn, createHookRule(table, fromChain, toChainName))
	switch {
	case err != nil:
		return fmt.Errorf("Failed to find hook rule: %w", err)
	case installed == nil:
		return nil
	}

	_ = conn.DelRule(installed)
	if err := conn.Flush(); err != nil {
		return fmt.Errorf("flush del hook rule: %w", err)
	}
	return nil
}
// DelHooks deletes the jump rules that AddHooks added to the conventional
// chains "INPUT", "FORWARD" (filter table, every available family) and
// "POSTROUTING" (nat table, every NAT-capable family). Jump rules that are
// not installed are skipped. The logf argument is not used by this
// implementation.
func (n *nftablesRunner) DelHooks(logf logger.Logf) error {
	conn := n.conn

	for _, table := range n.getTables() {
		inputChain, err := getChainFromTable(conn, table.Filter, "INPUT")
		if err != nil {
			return fmt.Errorf("get INPUT chain: %w", err)
		}
		err = delHookRule(conn, table.Filter, inputChain, chainNameInput)
		if err != nil {
			return fmt.Errorf("delhook: %w", err)
		}

		forwardChain, err := getChainFromTable(conn, table.Filter, "FORWARD")
		if err != nil {
			return fmt.Errorf("get FORWARD chain: %w", err)
		}
		err = delHookRule(conn, table.Filter, forwardChain, chainNameForward)
		if err != nil {
			return fmt.Errorf("delhook: %w", err)
		}
	}

	for _, table := range n.getNATTables() {
		postroutingChain, err := getChainFromTable(conn, table.Nat, "POSTROUTING")
		if err != nil {
			// Name the chain that was actually looked up.
			return fmt.Errorf("get POSTROUTING chain: %w", err)
		}
		err = delHookRule(conn, table.Nat, postroutingChain, chainNamePostrouting)
		if err != nil {
			return fmt.Errorf("delhook: %w", err)
		}
	}
	return nil
}
// maskof returns the 4-byte, big-endian network mask for pfx: the top
// pfx.Bits() bits are set and the remaining bits are clear.
func maskof(pfx netip.Prefix) []byte {
	// Host bits are the ones left after shifting the prefix length away;
	// the mask is their complement.
	hostBits := uint32(0xffff_ffff) >> pfx.Bits()
	var out [4]byte
	binary.BigEndian.PutUint32(out[:], ^hostBits)
	return out[:]
}
// createRangeRule builds (but does not install) a rule that applies decision
// to packets which did not arrive on the interface named tunname and whose
// source address falls inside rng (for example the CGNAT range or the
// ChromeOS VM range). The rule also carries a counter. Only IPv4 ranges are
// supported; an IPv6 range yields an error.
func createRangeRule(
	table *nftables.Table, chain *nftables.Chain,
	tunname string, rng netip.Prefix, decision expr.VerdictKind,
) (*nftables.Rule, error) {
	base := rng.Addr()
	if base.Is6() {
		return nil, errors.New("IPv6 is not supported")
	}
	loadSaddr, err := newLoadSaddrExpr(nftables.TableFamilyIPv4, 1)
	if err != nil {
		return nil, fmt.Errorf("newLoadSaddrExpr: %w", err)
	}
	network := base.AsSlice()
	netmask := maskof(rng)

	exprs := []expr.Any{
		// Input interface must not be tunname.
		&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
		&expr.Cmp{Op: expr.CmpOpNeq, Register: 1, Data: []byte(tunname)},
		// Source address, masked to the prefix length, must equal the
		// range's base address.
		loadSaddr,
		&expr.Bitwise{
			SourceRegister: 1,
			DestRegister:   1,
			Len:            4,
			Mask:           netmask,
			Xor:            []byte{0x00, 0x00, 0x00, 0x00},
		},
		&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: network},
		// Count, then apply the verdict.
		&expr.Counter{},
		&expr.Verdict{Kind: decision},
	}
	return &nftables.Rule{Table: table, Chain: chain, Exprs: exprs}, nil
}
// addReturnChromeOSVMRangeRule appends to chain a rule with a "return"
// verdict for packets whose source IP is in the ChromeOS VM range (and which
// did not arrive on tunname), then flushes the connection.
func addReturnChromeOSVMRangeRule(c *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error {
	returnRule, err := createRangeRule(table, chain, tunname, tsaddr.ChromeOSVMRange(), expr.VerdictReturn)
	if err != nil {
		return fmt.Errorf("create rule: %w", err)
	}

	c.AddRule(returnRule)
	if flushErr := c.Flush(); flushErr != nil {
		return fmt.Errorf("add rule: %w", flushErr)
	}
	return nil
}
// addDropCGNATRangeRule appends to chain a rule with a "drop" verdict for
// packets whose source IP is in the CGNAT range (and which did not arrive on
// tunname), then flushes the connection.
func addDropCGNATRangeRule(c *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error {
	dropRule, err := createRangeRule(table, chain, tunname, tsaddr.CGNATRange(), expr.VerdictDrop)
	if err != nil {
		return fmt.Errorf("create rule: %w", err)
	}

	c.AddRule(dropRule)
	if flushErr := c.Flush(); flushErr != nil {
		return fmt.Errorf("add rule: %w", flushErr)
	}
	return nil
}
// createSetSubnetRouteMarkRule builds a rule for packets whose input
// interface is tunname: it counts them and rewrites the packet mark using
// the negated Tailscale fwmark mask and the subnet route mark. The returned
// error is always nil.
func createSetSubnetRouteMarkRule(table *nftables.Table, chain *nftables.Chain, tunname string) (*nftables.Rule, error) {
	maskNeg := getTailscaleFwmarkMaskNeg()
	subnetRouteMark := getTailscaleSubnetRouteMark()

	exprs := make([]expr.Any, 0, 6)

	// Match on the input interface name.
	exprs = append(exprs,
		&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
		&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte(tunname)},
		&expr.Counter{},
	)

	// Load the packet mark, combine it with the mask and the subnet route
	// mark, and store the result back as the packet mark.
	exprs = append(exprs,
		&expr.Meta{Key: expr.MetaKeyMARK, Register: 1},
		&expr.Bitwise{
			SourceRegister: 1,
			DestRegister:   1,
			Len:            4,
			Mask:           maskNeg,
			Xor:            subnetRouteMark,
		},
		&expr.Meta{Key: expr.MetaKeyMARK, SourceRegister: true, Register: 1},
	)

	return &nftables.Rule{Table: table, Chain: chain, Exprs: exprs}, nil
}
// addSetSubnetRouteMarkRule builds the subnet-route-mark rule for packets
// arriving on tunname, adds it to chain and flushes the connection.
func addSetSubnetRouteMarkRule(c *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error {
	markRule, err := createSetSubnetRouteMarkRule(table, chain, tunname)
	if err != nil {
		return fmt.Errorf("create rule: %w", err)
	}

	c.AddRule(markRule)
	err = c.Flush()
	if err != nil {
		return fmt.Errorf("add rule: %w", err)
	}
	return nil
}
// createDropOutgoingPacketFromCGNATRangeRuleWithTunname builds a rule that
// drops packets whose output interface is tunname and whose source address
// is in the CGNAT range.
func createDropOutgoingPacketFromCGNATRangeRuleWithTunname(table *nftables.Table, chain *nftables.Chain, tunname string) (*nftables.Rule, error) {
	cgnat := tsaddr.CGNATRange().String()
	_, cgnatNet, err := net.ParseCIDR(cgnat)
	if err != nil {
		return nil, fmt.Errorf("parse cidr: %v", err)
	}

	// The mask prints as hex digits; decode them back into raw bytes.
	maskBytes, err := hex.DecodeString(cgnatNet.Mask.String())
	if err != nil {
		return nil, fmt.Errorf("decode mask: %v", err)
	}
	network := cgnatNet.IP.Mask(cgnatNet.Mask).To4()

	loadSaddr, err := newLoadSaddrExpr(nftables.TableFamilyIPv4, 1)
	if err != nil {
		return nil, fmt.Errorf("newLoadSaddrExpr: %v", err)
	}

	matchOutIface := []expr.Any{
		&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
		&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte(tunname)},
	}
	matchSource := []expr.Any{
		loadSaddr,
		&expr.Bitwise{
			SourceRegister: 1,
			DestRegister:   1,
			Len:            4,
			Mask:           maskBytes,
			Xor:            []byte{0x00, 0x00, 0x00, 0x00},
		},
		&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: network},
	}

	exprs := make([]expr.Any, 0, len(matchOutIface)+len(matchSource)+2)
	exprs = append(exprs, matchOutIface...)
	exprs = append(exprs, matchSource...)
	exprs = append(exprs, &expr.Counter{}, &expr.Verdict{Kind: expr.VerdictDrop})

	return &nftables.Rule{Table: table, Chain: chain, Exprs: exprs}, nil
}
// addDropOutgoingPacketFromCGNATRangeRuleWithTunname builds the rule that
// drops CGNAT-sourced packets leaving via tunname, adds it to chain and
// flushes the connection.
func addDropOutgoingPacketFromCGNATRangeRuleWithTunname(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error {
	dropRule, err := createDropOutgoingPacketFromCGNATRangeRuleWithTunname(table, chain, tunname)
	if err != nil {
		return fmt.Errorf("create rule: %w", err)
	}

	conn.AddRule(dropRule)
	err = conn.Flush()
	if err != nil {
		return fmt.Errorf("add rule: %w", err)
	}
	return nil
}
// createAcceptOutgoingPacketRule builds a rule that counts and accepts
// packets whose output interface is tunname.
func createAcceptOutgoingPacketRule(table *nftables.Table, chain *nftables.Chain, tunname string) *nftables.Rule {
	loadOutIface := &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}
	isTun := &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte(tunname)}
	accept := &expr.Verdict{Kind: expr.VerdictAccept}

	rule := &nftables.Rule{Table: table, Chain: chain}
	rule.Exprs = []expr.Any{loadOutIface, isTun, &expr.Counter{}, accept}
	return rule
}
// addAcceptOutgoingPacketRule adds to chain the rule that accepts packets
// leaving via tunname, then flushes the connection.
func addAcceptOutgoingPacketRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error {
	conn.AddRule(createAcceptOutgoingPacketRule(table, chain, tunname))

	err := conn.Flush()
	if err != nil {
		return fmt.Errorf("flush add rule: %w", err)
	}
	return nil
}
// createAcceptOnPortRule builds a rule that counts and accepts UDP packets
// whose destination port equals port.
func createAcceptOnPortRule(table *nftables.Table, chain *nftables.Chain, port uint16) *nftables.Rule {
	// The port is compared in big endian byte order.
	var dport [2]byte
	binary.BigEndian.PutUint16(dport[:], port)

	matchUDP := []expr.Any{
		&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
		&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}},
	}
	matchPort := []expr.Any{
		newLoadDportExpr(1),
		&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: dport[:]},
	}

	exprs := make([]expr.Any, 0, len(matchUDP)+len(matchPort)+2)
	exprs = append(exprs, matchUDP...)
	exprs = append(exprs, matchPort...)
	exprs = append(exprs, &expr.Counter{}, &expr.Verdict{Kind: expr.VerdictAccept})

	return &nftables.Rule{Table: table, Chain: chain, Exprs: exprs}
}
// addAcceptOnPortRule adds to chain the rule that accepts incoming UDP
// packets for the given destination port, then flushes the connection.
func addAcceptOnPortRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, port uint16) error {
	conn.AddRule(createAcceptOnPortRule(table, chain, port))

	err := conn.Flush()
	if err != nil {
		return fmt.Errorf("flush add rule: %w", err)
	}
	return nil
}
// removeAcceptOnPortRule removes the rule that accepts incoming packets to
// a given destination UDP port. It looks the rule up in chain with findRule
// and, when a match is found, deletes it and flushes the connection. When no
// matching rule is found there is nothing to remove and nil is returned.
func removeAcceptOnPortRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, port uint16) error {
	want := createAcceptOnPortRule(table, chain, port)
	found, err := findRule(conn, want)
	if err != nil {
		return fmt.Errorf("find rule: %w", err)
	}
	if found == nil {
		// NOTE(review): findRule appears to report "no match" as a nil
		// rule with a nil error (DelSNATRule guards against the same
		// case); passing nil on to DelRule is not safe.
		return nil
	}
	_ = conn.DelRule(found)
	if err := conn.Flush(); err != nil {
		return fmt.Errorf("flush del rule: %w", err)
	}
	return nil
}
// AddMagicsockPortRule installs an nftables rule in the Tailscale input
// chain that accepts incoming traffic on the given UDP port, so magicsock
// can accept incoming connections.
// network selects the address family: "udp4" targets the IPv4 filter table
// and "udp6" the IPv6 filter table; any other value is an error.
func (n *nftablesRunner) AddMagicsockPortRule(port uint16, network string) error {
	var tbl *nftables.Table
	if network == "udp4" {
		tbl = n.nft4.Filter
	} else if network == "udp6" {
		tbl = n.nft6.Filter
	} else {
		return fmt.Errorf("unsupported network %s", network)
	}

	chain, err := getChainFromTable(n.conn, tbl, chainNameInput)
	if err != nil {
		return fmt.Errorf("get input chain: %v", err)
	}
	if err := addAcceptOnPortRule(n.conn, tbl, chain, port); err != nil {
		return fmt.Errorf("add accept on port rule: %v", err)
	}
	return nil
}
// DelMagicsockPortRule removes a rule added by AddMagicsockPortRule to accept
// incoming traffic on a particular UDP port.
// network must be either "udp4" or "udp6" - this determines whether the rule
// is removed for IPv4 or IPv6. Any other value results in an error.
func (n *nftablesRunner) DelMagicsockPortRule(port uint16, network string) error {
	var filterTable *nftables.Table
	switch network {
	case "udp4":
		filterTable = n.nft4.Filter
	case "udp6":
		filterTable = n.nft6.Filter
	default:
		return fmt.Errorf("unsupported network %s", network)
	}
	inputChain, err := getChainFromTable(n.conn, filterTable, chainNameInput)
	if err != nil {
		return fmt.Errorf("get input chain: %w", err)
	}
	// Report the failure as a removal; the previous message claimed the
	// rule was being added.
	if err := removeAcceptOnPortRule(n.conn, filterTable, inputChain, port); err != nil {
		return fmt.Errorf("remove accept on port rule: %w", err)
	}
	return nil
}
// createAcceptIncomingPacketRule builds a rule that counts and accepts
// packets whose input interface is tunname.
func createAcceptIncomingPacketRule(table *nftables.Table, chain *nftables.Chain, tunname string) *nftables.Rule {
	loadInIface := &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}
	isTun := &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte(tunname)}
	accept := &expr.Verdict{Kind: expr.VerdictAccept}

	rule := &nftables.Rule{Table: table, Chain: chain}
	rule.Exprs = []expr.Any{loadInIface, isTun, &expr.Counter{}, accept}
	return rule
}
// addAcceptIncomingPacketRule adds to chain the rule that accepts packets
// arriving on tunname, then flushes the connection.
func addAcceptIncomingPacketRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error {
	conn.AddRule(createAcceptIncomingPacketRule(table, chain, tunname))

	err := conn.Flush()
	if err != nil {
		return fmt.Errorf("flush add rule: %w", err)
	}
	return nil
}
// AddBase installs the basic processing rules: always the IPv4 set, and the
// IPv6 set as well when HasIPV6 reports true.
func (n *nftablesRunner) AddBase(tunname string) error {
	err := n.addBase4(tunname)
	if err != nil {
		return fmt.Errorf("add base v4: %w", err)
	}
	if !n.HasIPV6() {
		return nil
	}
	err = n.addBase6(tunname)
	if err != nil {
		return fmt.Errorf("add base v6: %w", err)
	}
	return nil
}
// addBase4 installs the basic IPv4 rules into the Tailscale input and
// forward chains of the IPv4 filter table, in a fixed order, and stops at
// the first failure.
func (n *nftablesRunner) addBase4(tunname string) error {
	c := n.conn

	// Input chain: ChromeOS VM return, CGNAT drop, accept from tunname.
	in, err := getChainFromTable(c, n.nft4.Filter, chainNameInput)
	if err != nil {
		return fmt.Errorf("get input chain v4: %v", err)
	}
	err = addReturnChromeOSVMRangeRule(c, n.nft4.Filter, in, tunname)
	if err != nil {
		return fmt.Errorf("add return chromeos vm range rule v4: %w", err)
	}
	err = addDropCGNATRangeRule(c, n.nft4.Filter, in, tunname)
	if err != nil {
		return fmt.Errorf("add drop cgnat range rule v4: %w", err)
	}
	err = addAcceptIncomingPacketRule(c, n.nft4.Filter, in, tunname)
	if err != nil {
		return fmt.Errorf("add accept incoming packet rule v4: %w", err)
	}

	// Forward chain: set mark, match mark, CGNAT drop, accept to tunname.
	fwd, err := getChainFromTable(c, n.nft4.Filter, chainNameForward)
	if err != nil {
		return fmt.Errorf("get forward chain v4: %v", err)
	}
	err = addSetSubnetRouteMarkRule(c, n.nft4.Filter, fwd, tunname)
	if err != nil {
		return fmt.Errorf("add set subnet route mark rule v4: %w", err)
	}
	err = addMatchSubnetRouteMarkRule(c, n.nft4.Filter, fwd, Accept)
	if err != nil {
		return fmt.Errorf("add match subnet route mark rule v4: %w", err)
	}
	err = addDropOutgoingPacketFromCGNATRangeRuleWithTunname(c, n.nft4.Filter, fwd, tunname)
	if err != nil {
		return fmt.Errorf("add drop outgoing packet from cgnat range rule v4: %w", err)
	}
	err = addAcceptOutgoingPacketRule(c, n.nft4.Filter, fwd, tunname)
	if err != nil {
		return fmt.Errorf("add accept outgoing packet rule v4: %w", err)
	}

	err = c.Flush()
	if err != nil {
		return fmt.Errorf("flush base v4: %w", err)
	}
	return nil
}
// addBase6 adds some basic IPv6 processing rules to the Tailscale input and
// forward chains of the IPv6 filter table. Rules are added in a fixed order
// and the first failure is returned, wrapped with a "v6" context message.
func (n *nftablesRunner) addBase6(tunname string) error {
	conn := n.conn
	inputChain, err := getChainFromTable(conn, n.nft6.Filter, chainNameInput)
	if err != nil {
		// This is the IPv6 path; the message previously said "v4".
		return fmt.Errorf("get input chain v6: %w", err)
	}
	if err = addAcceptIncomingPacketRule(conn, n.nft6.Filter, inputChain, tunname); err != nil {
		return fmt.Errorf("add accept incoming packet rule v6: %w", err)
	}
	forwardChain, err := getChainFromTable(conn, n.nft6.Filter, chainNameForward)
	if err != nil {
		return fmt.Errorf("get forward chain v6: %w", err)
	}
	if err = addSetSubnetRouteMarkRule(conn, n.nft6.Filter, forwardChain, tunname); err != nil {
		return fmt.Errorf("add set subnet route mark rule v6: %w", err)
	}
	if err = addMatchSubnetRouteMarkRule(conn, n.nft6.Filter, forwardChain, Accept); err != nil {
		return fmt.Errorf("add match subnet route mark rule v6: %w", err)
	}
	if err = addAcceptOutgoingPacketRule(conn, n.nft6.Filter, forwardChain, tunname); err != nil {
		return fmt.Errorf("add accept outgoing packet rule v6: %w", err)
	}
	if err = conn.Flush(); err != nil {
		return fmt.Errorf("flush base v6: %w", err)
	}
	return nil
}
// DelBase empties, but does not remove, the custom Tailscale chains: the
// input and forward chains of every filter table and the postrouting chain
// of every NAT table. The result of the final connection flush is returned.
func (n *nftablesRunner) DelBase() error {
	c := n.conn

	for _, tbl := range n.getTables() {
		in, err := getChainFromTable(c, tbl.Filter, chainNameInput)
		if err != nil {
			return fmt.Errorf("get input chain: %v", err)
		}
		c.FlushChain(in)

		fwd, err := getChainFromTable(c, tbl.Filter, chainNameForward)
		if err != nil {
			return fmt.Errorf("get forward chain: %v", err)
		}
		c.FlushChain(fwd)
	}

	for _, tbl := range n.getNATTables() {
		post, err := getChainFromTable(c, tbl.Nat, chainNamePostrouting)
		if err != nil {
			return fmt.Errorf("get postrouting chain v4: %v", err)
		}
		c.FlushChain(post)
	}

	return c.Flush()
}
// createMatchSubnetRouteMarkRule builds a rule that matches packets whose
// mark, masked with the Tailscale fwmark mask, equals the subnet route mark.
// Matching packets are counted and then masqueraded when action is Masq, or
// accepted for any other action. The returned error is always nil.
func createMatchSubnetRouteMarkRule(table *nftables.Table, chain *nftables.Chain, action MatchDecision) (*nftables.Rule, error) {
	fwmarkMask := getTailscaleFwmarkMask()
	subnetRouteMark := getTailscaleSubnetRouteMark()

	var final expr.Any
	switch action {
	case Masq:
		final = &expr.Masq{}
	default:
		final = &expr.Verdict{Kind: expr.VerdictAccept}
	}

	loadMark := &expr.Meta{Key: expr.MetaKeyMARK, Register: 1}
	applyMask := &expr.Bitwise{
		SourceRegister: 1,
		DestRegister:   1,
		Len:            4,
		Mask:           fwmarkMask,
		Xor:            []byte{0x00, 0x00, 0x00, 0x00},
	}
	isSubnetRoute := &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: subnetRouteMark}

	return &nftables.Rule{
		Table: table,
		Chain: chain,
		Exprs: []expr.Any{loadMark, applyMask, isSubnetRoute, &expr.Counter{}, final},
	}, nil
}
// addMatchSubnetRouteMarkRule builds the subnet-route-mark matching rule
// for the given action, adds it to chain and flushes the connection.
func addMatchSubnetRouteMarkRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, action MatchDecision) error {
	matchRule, err := createMatchSubnetRouteMarkRule(table, chain, action)
	if err != nil {
		return fmt.Errorf("create match subnet route mark rule: %w", err)
	}

	conn.AddRule(matchRule)
	err = conn.Flush()
	if err != nil {
		return fmt.Errorf("flush add rule: %w", err)
	}
	return nil
}
// AddSNATRule adds, to the postrouting chain of every NAT table, a rule
// that masquerades packets carrying the subnet route mark, i.e. SNATs
// traffic destined for local subnets.
func (n *nftablesRunner) AddSNATRule() error {
	c := n.conn

	for _, nat := range n.getNATTables() {
		postrouting, err := getChainFromTable(c, nat.Nat, chainNamePostrouting)
		if err != nil {
			return fmt.Errorf("get postrouting chain v4: %w", err)
		}
		err = addMatchSubnetRouteMarkRule(c, nat.Nat, postrouting, Masq)
		if err != nil {
			return fmt.Errorf("add match subnet route mark rule v4: %w", err)
		}
	}

	err := c.Flush()
	if err != nil {
		return fmt.Errorf("flush add SNAT rule: %w", err)
	}
	return nil
}
// DelSNATRule removes the netfilter rule to SNAT traffic destined for
// local subnets from the postrouting chain of every NAT table. A table in
// which the rule is not found is skipped without error; errors are returned
// only when a chain lookup, the rule lookup or the final flush fails.
func (n *nftablesRunner) DelSNATRule() error {
	conn := n.conn
	for _, table := range n.getNATTables() {
		chain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting)
		if err != nil {
			return fmt.Errorf("get postrouting chain v4: %w", err)
		}
		// Build the lookup template with the same constructor that
		// AddSNATRule uses (via addMatchSubnetRouteMarkRule), so the
		// template carries exactly the expressions of the installed rule.
		// The previous hand-written template omitted the Bitwise Xor
		// operand. NOTE(review): assuming findRule compares expressions
		// for equality, that template could not match the installed rule.
		want, err := createMatchSubnetRouteMarkRule(table.Nat, chain, Masq)
		if err != nil {
			return fmt.Errorf("create match subnet route mark rule: %w", err)
		}
		SNATRule, err := findRule(conn, want)
		if err != nil {
			return fmt.Errorf("find SNAT rule v4: %w", err)
		}
		// A nil rule means there is nothing to delete in this table.
		if SNATRule != nil {
			_ = conn.DelRule(SNATRule)
		}
	}
	if err := conn.Flush(); err != nil {
		return fmt.Errorf("flush del SNAT rule: %w", err)
	}
	return nil
}
// cleanupChain removes a jump rule from hookChainName to tsChainName, and then
// the entire chain tsChainName. Errors are logged, but attempts to remove both
// the jump rule and chain continue even if one errors. A chain that does not
// exist is not treated as an error and is not logged.
func cleanupChain(logf logger.Logf, conn *nftables.Conn, table *nftables.Table, hookChainName, tsChainName string) {
	// remove the jump first, before removing the jump destination.
	defaultChain, err := getChainFromTable(conn, table, hookChainName)
	switch {
	case err == nil:
		// delete hook in convention chain. Only attempted when the lookup
		// succeeded, so delHookRule is never handed the result of a
		// failed lookup.
		if err := delHookRule(conn, table, defaultChain, tsChainName); err != nil {
			logf("cleanup: delete hook rule to %s: %s", tsChainName, err)
		}
	case !errors.Is(err, errorChainNotFound{table.Name, hookChainName}):
		logf("cleanup: did not find default chain: %s", err)
	}
	tsChain, err := getChainFromTable(conn, table, tsChainName)
	if err != nil && !errors.Is(err, errorChainNotFound{table.Name, tsChainName}) {
		logf("cleanup: did not find ts-chain: %s", err)
	}
	if tsChain != nil {
		// flush and delete ts-chain
		conn.FlushChain(tsChain)
		conn.DelChain(tsChain)
		// Only a failed flush is worth a log line; a nil error is not.
		if err := conn.Flush(); err != nil {
			logf("cleanup: delete and flush chain %s: %s", tsChainName, err)
		}
	}
}
// NfTablesCleanUp removes all Tailscale added nftables rules.
// Any errors that occur are logged to the provided logf. If the nftables
// connection cannot be created or the tables cannot be listed, the error is
// logged and the cleanup stops, since there is nothing to operate on.
func NfTablesCleanUp(logf logger.Logf) {
	conn, err := nftables.New()
	if err != nil {
		logf("cleanup: nftables connection: %s", err)
		// Without a connection none of the calls below can work.
		return
	}
	tables, err := conn.ListTables() // both v4 and v6
	if err != nil {
		logf("cleanup: list tables: %s", err)
		return
	}
	for _, table := range tables {
		switch table.Name {
		case "ts-filter", "ts-nat":
			// These table names were used briefly in 1.48.0.
			conn.DelTable(table)
			if err := conn.Flush(); err != nil {
				logf("cleanup: flush delete table %s: %s", table.Name, err)
			}
		case "filter":
			cleanupChain(logf, conn, table, "INPUT", chainNameInput)
			cleanupChain(logf, conn, table, "FORWARD", chainNameForward)
		case "nat":
			cleanupChain(logf, conn, table, "POSTROUTING", chainNamePostrouting)
		}
	}
}