tailscale/util/linuxfw/nftables_runner.go
James Tucker ba6ec42f6d util/linuxfw: add missing input rule to the tailscale tun
Add an explicit accept rule for input to the tun interface, as a mirror
to the explicit rule to accept output from the tun interface.

The rule matches any packet in to our tun interface and accepts it, and
the rule is positioned and prioritized such that it should be evaluated
prior to conventional ufw/iptables/nft rules.

Updates #391
Fixes #7332
Updates #9084

Signed-off-by: James Tucker <james@tailscale.com>
2023-10-10 17:22:47 -07:00

1226 lines
36 KiB
Go

// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build linux
package linuxfw
import (
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
"net"
"net/netip"
"reflect"
"strings"
"github.com/google/nftables"
"github.com/google/nftables/expr"
"tailscale.com/net/tsaddr"
"tailscale.com/types/logger"
)
// Names of the Tailscale-managed chains, together with the chain type used
// for chains that are not attached to a netfilter hook.
const (
	chainNameForward     = "ts-forward"
	chainNameInput       = "ts-input"
	chainNamePostrouting = "ts-postrouting"

	// chainTypeRegular is an nftables chain that does not apply to a hook.
	chainTypeRegular = ""
)
// chainInfo describes a chain that createChainIfNotExist should look up and,
// if missing, create. Each field is copied into the corresponding field of
// nftables.Chain when the chain is added.
//
// NOTE: AddChains builds chainInfo values with positional literals, so the
// field order below must not change.
type chainInfo struct {
// table is the table the chain lives in.
table *nftables.Table
// name is the chain name looked up within table.
name string
// chainType is the nftables chain type; chainTypeRegular for unhooked chains.
chainType nftables.ChainType
// chainHook is the hook the chain attaches to; nil for the ts-* chains.
chainHook *nftables.ChainHook
// chainPriority is the chain's priority; nil for the ts-* chains.
chainPriority *nftables.ChainPriority
// chainPolicy is the chain's default policy; nil for the ts-* chains.
chainPolicy *nftables.ChainPolicy
}
// nftable groups the tables used by the runner for a single address family.
type nftable struct {
// Proto is the address family (IPv4 or IPv6) of the tables.
Proto nftables.TableFamily
// Filter is the "filter" table; it is populated by AddChains.
Filter *nftables.Table
// Nat is the "nat" table; it is populated by AddChains for families
// returned by getNATTables.
Nat *nftables.Table
}
// nftablesRunner implements a netfilterRunner using the netlink based nftables
// library. As nftables allows for arbitrary tables and chains, there is a need
// to follow conventions in order to integrate well with a surrounding
// ecosystem. The rules installed by nftablesRunner have the following
// properties:
// - Install rules that intend to take precedence over rules installed by
// other software. Tailscale provides packet filtering for tailnet traffic
// inside the daemon based on the tailnet ACL rules.
// - As nftables "accept" is not final, rules from high priority tables (low
// numbers) will fall through to lower priority tables (high numbers). In
// order to effectively be 'final', we install "jump" rules into conventional
// tables and chains that will reach an accept verdict inside those tables.
// - The table and chain conventions followed here are those used by
// `iptables-nft` and `ufw`, so that those tools co-exist and do not
// negatively affect Tailscale function.
type nftablesRunner struct {
// conn is the netlink connection used for every nftables operation.
conn *nftables.Conn
// nft4 holds the IPv4 tables; always set by NewNfTablesRunner.
nft4 *nftable
// nft6 holds the IPv6 tables; NewNfTablesRunner leaves it nil when the
// IPv6 check fails.
nft6 *nftable
// v6Available reports whether the IPv6 check in NewNfTablesRunner passed.
v6Available bool
// v6NATAvailable reports whether IPv6 is available and
// checkSupportsV6NAT returned true.
v6NATAvailable bool
}
// createTableIfNotExist returns the table called name in the given family,
// creating it over connection c when no such table is currently listed.
func createTableIfNotExist(c *nftables.Conn, family nftables.TableFamily, name string) (*nftables.Table, error) {
	existing, err := c.ListTables()
	if err != nil {
		return nil, fmt.Errorf("get tables: %w", err)
	}
	for _, tbl := range existing {
		if tbl.Family != family || tbl.Name != name {
			continue
		}
		return tbl, nil
	}

	// No match: queue the creation and commit it immediately.
	created := c.AddTable(&nftables.Table{Name: name, Family: family})
	if err = c.Flush(); err != nil {
		return nil, fmt.Errorf("add table: %w", err)
	}
	return created, nil
}
// errorChainNotFound is the error returned when a chain lookup finds no chain
// with the requested name in the requested table.
//
// The fields are declared table-first because every construction site in this
// file uses the positional form errorChainNotFound{table.Name, name}. With the
// previous chain-first declaration those literals stored the table name as the
// chain name (and vice versa), so the rendered message was reversed.
type errorChainNotFound struct {
	tableName string
	chainName string
}

// Error implements the error interface.
func (e errorChainNotFound) Error() string {
	return fmt.Sprintf("chain %s not found in table %s", e.chainName, e.tableName)
}
// getChainFromTable looks up the chain called name in table. Chain names are
// unique within a table. When no chain matches, an errorChainNotFound value is
// returned.
func getChainFromTable(c *nftables.Conn, table *nftables.Table, name string) (*nftables.Chain, error) {
	candidates, err := c.ListChainsOfTableFamily(table.Family)
	if err != nil {
		return nil, fmt.Errorf("list chains: %w", err)
	}
	for _, ch := range candidates {
		// The listing is already restricted to the table's family, so the
		// table name alone identifies the table.
		if ch.Name != name || ch.Table.Name != table.Name {
			continue
		}
		return ch, nil
	}
	return nil, errorChainNotFound{table.Name, name}
}
// getChainsFromTable lists every chain that belongs to table. The result is
// nil when the table has no chains.
func getChainsFromTable(c *nftables.Conn, table *nftables.Table) ([]*nftables.Chain, error) {
	all, err := c.ListChainsOfTableFamily(table.Family)
	if err != nil {
		return nil, fmt.Errorf("list chains: %w", err)
	}
	var matched []*nftables.Chain
	for _, ch := range all {
		// The listing is already restricted to the table's family, so the
		// table name alone identifies the table.
		if ch.Table.Name != table.Name {
			continue
		}
		matched = append(matched, ch)
	}
	return matched, nil
}
// isTSChain reports whether name carries the "ts-" prefix that marks a
// Tailscale-managed chain.
func isTSChain(name string) bool {
	const tsPrefix = "ts-"
	return strings.HasPrefix(name, tsPrefix)
}
// createChainIfNotExist creates the chain described by cinfo unless a chain
// with that name already exists in cinfo.table.
//
// An existing Tailscale chain (see isTSChain) must agree with cinfo on type,
// hook and priority, otherwise an error is returned. Existing "conventional"
// chains are accepted as they are, in case iptables-nft/ufw make minor
// behavior changes in the future.
func createChainIfNotExist(c *nftables.Conn, cinfo chainInfo) error {
	chain, err := getChainFromTable(c, cinfo.table, cinfo.name)
	switch {
	case err == nil:
		if !isTSChain(chain.Name) {
			return nil
		}
		// Hooknum and Priority are pointers. Compare what they point at
		// (treating nil as equal only to nil) rather than pointer identity,
		// which would differ for equal values held in separate allocations.
		sameHook := (chain.Hooknum == nil) == (cinfo.chainHook == nil) &&
			(chain.Hooknum == nil || *chain.Hooknum == *cinfo.chainHook)
		samePriority := (chain.Priority == nil) == (cinfo.chainPriority == nil) &&
			(chain.Priority == nil || *chain.Priority == *cinfo.chainPriority)
		if chain.Type != cinfo.chainType || !sameHook || !samePriority {
			return fmt.Errorf("chain %s already exists with different type/hook/priority", cinfo.name)
		}
		return nil
	case !errors.Is(err, errorChainNotFound{cinfo.table.Name, cinfo.name}):
		return fmt.Errorf("get chain: %w", err)
	}

	// The chain does not exist yet: queue its creation and commit.
	_ = c.AddChain(&nftables.Chain{
		Name:     cinfo.name,
		Table:    cinfo.table,
		Type:     cinfo.chainType,
		Hooknum:  cinfo.chainHook,
		Priority: cinfo.chainPriority,
		Policy:   cinfo.chainPolicy,
	})
	if err := c.Flush(); err != nil {
		return fmt.Errorf("add chain: %w", err)
	}
	return nil
}
// NewNfTablesRunner opens an nftables connection and returns a runner bound
// to it. The tables and chains are not created here; IPv6 state is only
// populated when the system IPv6 check succeeds.
func NewNfTablesRunner(logf logger.Logf) (*nftablesRunner, error) {
	conn, err := nftables.New()
	if err != nil {
		return nil, fmt.Errorf("nftables connection: %w", err)
	}

	// TODO(KevinLiang10): convert iptables rule to nftable rules if they exist in the iptables
	runner := &nftablesRunner{
		conn: conn,
		nft4: &nftable{Proto: nftables.TableFamilyIPv4},
	}

	if v6err := checkIPv6(logf); v6err != nil {
		logf("disabling tunneled IPv6 due to system IPv6 config: %v", v6err)
		return runner, nil
	}

	// IPv6 is usable; NAT support is probed only in that case.
	runner.v6Available = true
	runner.v6NATAvailable = checkSupportsV6NAT()
	logf("v6nat availability: %v", runner.v6NATAvailable)
	runner.nft6 = &nftable{Proto: nftables.TableFamilyIPv6}
	return runner, nil
}
// newLoadSaddrExpr builds a payload expression that copies the packet's source
// address from the network header into register destReg. Only the IPv4 and
// IPv6 table families are accepted.
func newLoadSaddrExpr(proto nftables.TableFamily, destReg uint32) (expr.Any, error) {
	load := &expr.Payload{
		DestRegister: destReg,
		Base:         expr.PayloadBaseNetworkHeader,
	}
	// Offset and length of the source address field for each family.
	switch proto {
	case nftables.TableFamilyIPv4:
		load.Offset, load.Len = 12, 4
	case nftables.TableFamilyIPv6:
		load.Offset, load.Len = 8, 16
	default:
		return nil, fmt.Errorf("table family %v is neither IPv4 nor IPv6", proto)
	}
	return load, nil
}
// HasIPV6 reports whether IPv6 was found usable when the runner was created
// (the v6Available field set by NewNfTablesRunner).
func (n *nftablesRunner) HasIPV6() bool {
return n.v6Available
}
// HasIPV6NAT reports whether IPv6 NAT was found usable when the runner was
// created (the v6NATAvailable field set by NewNfTablesRunner).
func (n *nftablesRunner) HasIPV6NAT() bool {
return n.v6NATAvailable
}
// findRule searches rule.Chain in rule.Table for an installed rule whose
// expressions match those of rule, position by position. It returns the
// installed rule, or nil (with a nil error) when nothing matches.
//
// Counter expressions carry live packet/byte counts and therefore never
// compare equal; a position is accepted without comparison only when both the
// installed and the wanted expression at that position are counters.
func findRule(conn *nftables.Conn, rule *nftables.Rule) (*nftables.Rule, error) {
	rules, err := conn.GetRules(rule.Table, rule.Chain)
	if err != nil {
		return nil, fmt.Errorf("get nftables rules: %w", err)
	}

ruleLoop:
	for _, r := range rules {
		if len(r.Exprs) != len(rule.Exprs) {
			continue
		}
		for i, got := range r.Exprs {
			want := rule.Exprs[i]
			_, gotIsCounter := got.(*expr.Counter)
			_, wantIsCounter := want.(*expr.Counter)
			if gotIsCounter && wantIsCounter {
				continue
			}
			if !reflect.DeepEqual(got, want) {
				continue ruleLoop
			}
		}
		return r, nil
	}
	return nil, nil
}
// createLoopbackRule builds (without installing) a rule for chain that accepts
// packets arriving on interface "lo" whose source address equals addr.
func createLoopbackRule(
	proto nftables.TableFamily,
	table *nftables.Table,
	chain *nftables.Chain,
	addr netip.Addr,
) (*nftables.Rule, error) {
	loadSaddr, err := newLoadSaddrExpr(proto, 1)
	if err != nil {
		return nil, fmt.Errorf("newLoadSaddrExpr: %w", err)
	}

	var (
		loadIIFName = &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}
		isLoopback  = &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte("lo")}
		isAddr      = &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: addr.AsSlice()}
		accept      = &expr.Verdict{Kind: expr.VerdictAccept}
	)
	return &nftables.Rule{
		Table: table,
		Chain: chain,
		Exprs: []expr.Any{loadIIFName, isLoopback, loadSaddr, isAddr, &expr.Counter{}, accept},
	}, nil
}
// insertLoopbackRule puts the Tailscale loopback rule for addr at the top of
// chain, unless an equivalent rule is already installed.
func insertLoopbackRule(
	conn *nftables.Conn, proto nftables.TableFamily,
	table *nftables.Table, chain *nftables.Chain, addr netip.Addr) error {
	wanted, err := createLoopbackRule(proto, table, chain, addr)
	if err != nil {
		return fmt.Errorf("create loopback rule: %w", err)
	}

	// When TestDial is set we are running under test, where looking up the
	// rule would fail on a header mismatch, so the existence check is skipped.
	if conn.TestDial == nil {
		existing, err := findRule(conn, wanted)
		switch {
		case err != nil:
			return fmt.Errorf("find rule: %w", err)
		case existing != nil:
			// Already installed; nothing to do.
			return nil
		}
	}

	// InsertRule places the rule at the head of the chain.
	_ = conn.InsertRule(wanted)
	if err = conn.Flush(); err != nil {
		return fmt.Errorf("insert rule: %w", err)
	}
	return nil
}
// getNFTByAddr selects the nftable whose address family matches addr: the
// IPv6 one when addr.Is6() is true, the IPv4 one otherwise.
func (n *nftablesRunner) getNFTByAddr(addr netip.Addr) *nftable {
	if !addr.Is6() {
		return n.nft4
	}
	return n.nft6
}
// AddLoopbackRule adds an nftables rule to permit loopback traffic to
// a local Tailscale IP. This rule is added only if it does not already exist.
//
// An error is returned when no table set exists for addr's family, which is
// the case for IPv6 addresses on a runner created without IPv6 support.
func (n *nftablesRunner) AddLoopbackRule(addr netip.Addr) error {
	nf := n.getNFTByAddr(addr)
	if nf == nil {
		// nft6 is nil when IPv6 is unavailable; fail instead of dereferencing it.
		return fmt.Errorf("add loopback rule: no nftables tables for address %v", addr)
	}

	inputChain, err := getChainFromTable(n.conn, nf.Filter, chainNameInput)
	if err != nil {
		return fmt.Errorf("get input chain: %w", err)
	}
	if err := insertLoopbackRule(n.conn, nf.Proto, nf.Filter, inputChain, addr); err != nil {
		return fmt.Errorf("add loopback rule: %w", err)
	}
	return nil
}
// DelLoopbackRule removes the nftables rule permitting loopback
// traffic to a Tailscale IP. It is a no-op when the rule is not installed.
//
// An error is returned when no table set exists for addr's family, which is
// the case for IPv6 addresses on a runner created without IPv6 support.
func (n *nftablesRunner) DelLoopbackRule(addr netip.Addr) error {
	nf := n.getNFTByAddr(addr)
	if nf == nil {
		// nft6 is nil when IPv6 is unavailable; fail instead of dereferencing it.
		return fmt.Errorf("delete loopback rule: no nftables tables for address %v", addr)
	}

	inputChain, err := getChainFromTable(n.conn, nf.Filter, chainNameInput)
	if err != nil {
		return fmt.Errorf("get input chain: %w", err)
	}

	loopBackRule, err := createLoopbackRule(nf.Proto, nf.Filter, inputChain, addr)
	if err != nil {
		return fmt.Errorf("create loopback rule: %w", err)
	}

	existingLoopBackRule, err := findRule(n.conn, loopBackRule)
	if err != nil {
		return fmt.Errorf("find loop back rule: %w", err)
	}
	if existingLoopBackRule == nil {
		// Rule does not exist, no need to delete.
		return nil
	}

	if err := n.conn.DelRule(existingLoopBackRule); err != nil {
		return fmt.Errorf("delete rule: %w", err)
	}
	return n.conn.Flush()
}
// getTables lists the per-family table sets in use: always the IPv4 one,
// followed by the IPv6 one when v6Available is set.
func (n *nftablesRunner) getTables() []*nftable {
	if !n.v6Available {
		return []*nftable{n.nft4}
	}
	return []*nftable{n.nft4, n.nft6}
}
// getNATTables lists the per-family table sets that may carry NAT rules.
// Without IPv6 NAT support only the IPv4 set is returned; otherwise the result
// is whatever getTables reports.
func (n *nftablesRunner) getNATTables() []*nftable {
	if !n.v6NATAvailable {
		return []*nftable{n.nft4}
	}
	return n.getTables()
}
// AddChains makes sure the "filter" and "nat" tables, their conventional
// hooked chains, and the Tailscale ts-* chains exist, creating whatever is
// missing. The table handles are stored on the runner's nftable values.
func (n *nftablesRunner) AddChains() error {
	polAccept := nftables.ChainPolicyAccept

	type chainSpec struct {
		label string // prefix of the error reported when creation fails
		info  chainInfo
	}
	ensure := func(specs ...chainSpec) error {
		for _, spec := range specs {
			if err := createChainIfNotExist(n.conn, spec.info); err != nil {
				return fmt.Errorf("%s: %w", spec.label, err)
			}
		}
		return nil
	}

	for _, table := range n.getTables() {
		// The table name matches the one used by iptables-nft and ufw. Rules
		// go into that conventional table so that `accept` verdicts from our
		// jump chains are conclusive.
		filter, err := createTableIfNotExist(n.conn, table.Proto, "filter")
		if err != nil {
			return fmt.Errorf("create table: %w", err)
		}
		table.Filter = filter

		err = ensure(
			// The "conventional chains" used by iptables-nft and ufw.
			chainSpec{"create forward chain", chainInfo{
				table: filter, name: "FORWARD", chainType: nftables.ChainTypeFilter,
				chainHook: nftables.ChainHookForward, chainPriority: nftables.ChainPriorityFilter,
				chainPolicy: &polAccept,
			}},
			chainSpec{"create input chain", chainInfo{
				table: filter, name: "INPUT", chainType: nftables.ChainTypeFilter,
				chainHook: nftables.ChainHookInput, chainPriority: nftables.ChainPriorityFilter,
				chainPolicy: &polAccept,
			}},
			// The tailscale chains that contain our rules.
			chainSpec{"create forward chain", chainInfo{
				table: filter, name: chainNameForward, chainType: chainTypeRegular,
			}},
			chainSpec{"create input chain", chainInfo{
				table: filter, name: chainNameInput, chainType: chainTypeRegular,
			}},
		)
		if err != nil {
			return err
		}
	}

	for _, table := range n.getNATTables() {
		// Same convention as above, for the nat table.
		nat, err := createTableIfNotExist(n.conn, table.Proto, "nat")
		if err != nil {
			return fmt.Errorf("create table: %w", err)
		}
		table.Nat = nat

		err = ensure(
			// The "conventional chain" used by iptables-nft and ufw.
			chainSpec{"create postrouting chain", chainInfo{
				table: nat, name: "POSTROUTING", chainType: nftables.ChainTypeNAT,
				chainHook: nftables.ChainHookPostrouting, chainPriority: nftables.ChainPriorityNATSource,
				chainPolicy: &polAccept,
			}},
			// The tailscale chain that contains our rules.
			chainSpec{"create postrouting chain", chainInfo{
				table: nat, name: chainNamePostrouting, chainType: chainTypeRegular,
			}},
		)
		if err != nil {
			return err
		}
	}

	return n.conn.Flush()
}
// deleteChainIfExists flushes and deletes the chain called name in table.
// A chain that is already absent is not an error.
func deleteChainIfExists(c *nftables.Conn, table *nftables.Table, name string) error {
	chain, err := getChainFromTable(c, table, name)
	if err != nil {
		if errors.Is(err, errorChainNotFound{table.Name, name}) {
			// Nothing to delete.
			return nil
		}
		return fmt.Errorf("get chain: %w", err)
	}

	// Empty the chain first, then remove it, in one committed batch.
	c.FlushChain(chain)
	c.DelChain(chain)
	if flushErr := c.Flush(); flushErr != nil {
		return fmt.Errorf("flush and delete chain: %w", flushErr)
	}
	return nil
}
// DelChains deletes the ts-forward and ts-input chains from every filter
// table, then the ts-postrouting chain from the IPv4 nat table and, when IPv6
// NAT is available, from the IPv6 nat table.
func (n *nftablesRunner) DelChains() error {
	drop := func(table *nftables.Table, chainName string) error {
		if err := deleteChainIfExists(n.conn, table, chainName); err != nil {
			return fmt.Errorf("delete chain: %w", err)
		}
		return nil
	}

	for _, table := range n.getTables() {
		if err := drop(table.Filter, chainNameForward); err != nil {
			return err
		}
		if err := drop(table.Filter, chainNameInput); err != nil {
			return err
		}
	}

	if err := drop(n.nft4.Nat, chainNamePostrouting); err != nil {
		return err
	}
	if n.v6NATAvailable {
		if err := drop(n.nft6.Nat, chainNamePostrouting); err != nil {
			return err
		}
	}

	if err := n.conn.Flush(); err != nil {
		return fmt.Errorf("flush: %w", err)
	}
	return nil
}
// createHookRule builds (without installing) a rule for fromChain that counts
// the packet and then jumps to the chain named toChainName.
func createHookRule(table *nftables.Table, fromChain *nftables.Chain, toChainName string) *nftables.Rule {
	jump := &expr.Verdict{
		Kind:  expr.VerdictJump,
		Chain: toChainName,
	}
	return &nftables.Rule{
		Table: table,
		Chain: fromChain,
		Exprs: []expr.Any{&expr.Counter{}, jump},
	}
}
// addHookRule inserts, at the top of fromChain, a rule that jumps to the chain
// named toChainName, and commits the change.
func addHookRule(conn *nftables.Conn, table *nftables.Table, fromChain *nftables.Chain, toChainName string) error {
	_ = conn.InsertRule(createHookRule(table, fromChain, toChainName))
	err := conn.Flush()
	if err != nil {
		return fmt.Errorf("flush add rule: %w", err)
	}
	return nil
}
// AddHooks adds rules to the conventional chains "INPUT" and "FORWARD" (filter
// tables) and "POSTROUTING" (nat tables) that jump to the corresponding
// Tailscale chains. It stops at the first error.
func (n *nftablesRunner) AddHooks() error {
	conn := n.conn

	for _, table := range n.getTables() {
		inputChain, err := getChainFromTable(conn, table.Filter, "INPUT")
		if err != nil {
			return fmt.Errorf("get INPUT chain: %w", err)
		}
		err = addHookRule(conn, table.Filter, inputChain, chainNameInput)
		if err != nil {
			return fmt.Errorf("Addhook: %w", err)
		}

		forwardChain, err := getChainFromTable(conn, table.Filter, "FORWARD")
		if err != nil {
			return fmt.Errorf("get FORWARD chain: %w", err)
		}
		err = addHookRule(conn, table.Filter, forwardChain, chainNameForward)
		if err != nil {
			return fmt.Errorf("Addhook: %w", err)
		}
	}

	for _, table := range n.getNATTables() {
		postroutingChain, err := getChainFromTable(conn, table.Nat, "POSTROUTING")
		if err != nil {
			// Name the chain that was actually looked up (this used to say INPUT).
			return fmt.Errorf("get POSTROUTING chain: %w", err)
		}
		err = addHookRule(conn, table.Nat, postroutingChain, chainNamePostrouting)
		if err != nil {
			return fmt.Errorf("Addhook: %w", err)
		}
	}
	return nil
}
// delHookRule removes from fromChain the rule that jumps to the chain named
// toChainName. It does nothing when no such rule is installed.
func delHookRule(conn *nftables.Conn, table *nftables.Table, fromChain *nftables.Chain, toChainName string) error {
	installed, err := findRule(conn, createHookRule(table, fromChain, toChainName))
	switch {
	case err != nil:
		return fmt.Errorf("Failed to find hook rule: %w", err)
	case installed == nil:
		return nil
	}

	_ = conn.DelRule(installed)
	if flushErr := conn.Flush(); flushErr != nil {
		return fmt.Errorf("flush del hook rule: %w", flushErr)
	}
	return nil
}
// DelHooks deletes the jump rules that AddHooks installs in the conventional
// "INPUT", "FORWARD" and "POSTROUTING" chains. It stops at the first error.
// The logf parameter is not used by this implementation.
func (n *nftablesRunner) DelHooks(logf logger.Logf) error {
	conn := n.conn

	for _, table := range n.getTables() {
		inputChain, err := getChainFromTable(conn, table.Filter, "INPUT")
		if err != nil {
			return fmt.Errorf("get INPUT chain: %w", err)
		}
		err = delHookRule(conn, table.Filter, inputChain, chainNameInput)
		if err != nil {
			return fmt.Errorf("delhook: %w", err)
		}

		forwardChain, err := getChainFromTable(conn, table.Filter, "FORWARD")
		if err != nil {
			return fmt.Errorf("get FORWARD chain: %w", err)
		}
		err = delHookRule(conn, table.Filter, forwardChain, chainNameForward)
		if err != nil {
			return fmt.Errorf("delhook: %w", err)
		}
	}

	for _, table := range n.getNATTables() {
		postroutingChain, err := getChainFromTable(conn, table.Nat, "POSTROUTING")
		if err != nil {
			// Name the chain that was actually looked up (this used to say INPUT).
			return fmt.Errorf("get POSTROUTING chain: %w", err)
		}
		err = delHookRule(conn, table.Nat, postroutingChain, chainNamePostrouting)
		if err != nil {
			return fmt.Errorf("delhook: %w", err)
		}
	}
	return nil
}
// maskof returns the 4-byte, big-endian network mask that has the top
// pfx.Bits() bits set.
func maskof(pfx netip.Prefix) []byte {
	var buf [4]byte
	hostBits := uint32(0xffff_ffff) >> pfx.Bits()
	binary.BigEndian.PutUint32(buf[:], ^hostBits)
	return buf[:]
}
// createRangeRule builds (without installing) a rule that applies decision to
// packets whose input interface is not tunname and whose source address falls
// inside rng (for example the CGNAT or ChromeOS VM range). Only IPv4 ranges
// are supported; an IPv6 range yields an error.
func createRangeRule(
	table *nftables.Table, chain *nftables.Chain,
	tunname string, rng netip.Prefix, decision expr.VerdictKind,
) (*nftables.Rule, error) {
	if rng.Addr().Is6() {
		return nil, errors.New("IPv6 is not supported")
	}
	loadSaddr, err := newLoadSaddrExpr(nftables.TableFamilyIPv4, 1)
	if err != nil {
		return nil, fmt.Errorf("newLoadSaddrExpr: %w", err)
	}

	rangeAddr := rng.Addr().AsSlice()
	rangeMask := maskof(rng)

	exprs := []expr.Any{
		// input interface name != tunname
		&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
		&expr.Cmp{Op: expr.CmpOpNeq, Register: 1, Data: []byte(tunname)},
		// (source address & mask) == range address
		loadSaddr,
		&expr.Bitwise{
			SourceRegister: 1,
			DestRegister:   1,
			Len:            4,
			Mask:           rangeMask,
			Xor:            []byte{0x00, 0x00, 0x00, 0x00},
		},
		&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: rangeAddr},
		&expr.Counter{},
		&expr.Verdict{Kind: decision},
	}
	return &nftables.Rule{Table: table, Chain: chain, Exprs: exprs}, nil
}
// addReturnChromeOSVMRangeRule appends to chain a rule that returns when the
// source address is in the ChromeOS VM range and the packet did not arrive on
// tunname, and commits it.
func addReturnChromeOSVMRangeRule(c *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error {
	vmRange := tsaddr.ChromeOSVMRange()
	returnRule, err := createRangeRule(table, chain, tunname, vmRange, expr.VerdictReturn)
	if err != nil {
		return fmt.Errorf("create rule: %w", err)
	}
	_ = c.AddRule(returnRule)
	if flushErr := c.Flush(); flushErr != nil {
		return fmt.Errorf("add rule: %w", flushErr)
	}
	return nil
}
// addDropCGNATRangeRule appends to chain a rule that drops packets whose
// source address is in the CGNAT range and that did not arrive on tunname,
// and commits it.
func addDropCGNATRangeRule(c *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error {
	cgnat := tsaddr.CGNATRange()
	dropRule, err := createRangeRule(table, chain, tunname, cgnat, expr.VerdictDrop)
	if err != nil {
		return fmt.Errorf("create rule: %w", err)
	}
	_ = c.AddRule(dropRule)
	if flushErr := c.Flush(); flushErr != nil {
		return fmt.Errorf("add rule: %w", flushErr)
	}
	return nil
}
// createSetSubnetRouteMarkRule builds (without installing) a rule that, for
// packets arriving on tunname, rewrites the packet mark: the current mark is
// ANDed with the negated Tailscale fwmark mask, XORed with the subnet route
// mark, and written back. The returned error is always nil.
func createSetSubnetRouteMarkRule(table *nftables.Table, chain *nftables.Chain, tunname string) (*nftables.Rule, error) {
	keepBits := getTailscaleFwmarkMaskNeg()
	routeMark := getTailscaleSubnetRouteMark()

	exprs := []expr.Any{
		// input interface name == tunname
		&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
		&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte(tunname)},
		&expr.Counter{},
		// mark = (mark & keepBits) ^ routeMark
		&expr.Meta{Key: expr.MetaKeyMARK, Register: 1},
		&expr.Bitwise{
			SourceRegister: 1,
			DestRegister:   1,
			Len:            4,
			Mask:           keepBits,
			Xor:            routeMark,
		},
		&expr.Meta{Key: expr.MetaKeyMARK, SourceRegister: true, Register: 1},
	}
	return &nftables.Rule{Table: table, Chain: chain, Exprs: exprs}, nil
}
// addSetSubnetRouteMarkRule appends to chain the rule that sets the subnet
// route mark on packets arriving on tunname, and commits it.
func addSetSubnetRouteMarkRule(c *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error {
	markRule, err := createSetSubnetRouteMarkRule(table, chain, tunname)
	if err != nil {
		return fmt.Errorf("create rule: %w", err)
	}
	_ = c.AddRule(markRule)
	if flushErr := c.Flush(); flushErr != nil {
		return fmt.Errorf("add rule: %w", flushErr)
	}
	return nil
}
// createDropOutgoingPacketFromCGNATRangeRuleWithTunname builds (without
// installing) a rule that drops packets leaving through tunname whose source
// address is in the CGNAT range.
func createDropOutgoingPacketFromCGNATRangeRuleWithTunname(table *nftables.Table, chain *nftables.Chain, tunname string) (*nftables.Rule, error) {
	_, cgnat, err := net.ParseCIDR(tsaddr.CGNATRange().String())
	if err != nil {
		return nil, fmt.Errorf("parse cidr: %v", err)
	}
	// The mask's String form is hex text; decode it back into raw bytes.
	maskBytes, err := hex.DecodeString(cgnat.Mask.String())
	if err != nil {
		return nil, fmt.Errorf("decode mask: %v", err)
	}
	networkAddr := cgnat.IP.Mask(cgnat.Mask).To4()

	loadSaddr, err := newLoadSaddrExpr(nftables.TableFamilyIPv4, 1)
	if err != nil {
		return nil, fmt.Errorf("newLoadSaddrExpr: %v", err)
	}

	exprs := []expr.Any{
		// output interface name == tunname
		&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
		&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte(tunname)},
		// (source address & mask) == CGNAT network address
		loadSaddr,
		&expr.Bitwise{
			SourceRegister: 1,
			DestRegister:   1,
			Len:            4,
			Mask:           maskBytes,
			Xor:            []byte{0x00, 0x00, 0x00, 0x00},
		},
		&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: networkAddr},
		&expr.Counter{},
		&expr.Verdict{Kind: expr.VerdictDrop},
	}
	return &nftables.Rule{Table: table, Chain: chain, Exprs: exprs}, nil
}
// addDropOutgoingPacketFromCGNATRangeRuleWithTunname appends to chain the rule
// that drops packets leaving through tunname with a CGNAT-range source
// address, and commits it.
func addDropOutgoingPacketFromCGNATRangeRuleWithTunname(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error {
	dropRule, err := createDropOutgoingPacketFromCGNATRangeRuleWithTunname(table, chain, tunname)
	if err != nil {
		return fmt.Errorf("create rule: %w", err)
	}
	_ = conn.AddRule(dropRule)
	if flushErr := conn.Flush(); flushErr != nil {
		return fmt.Errorf("add rule: %w", flushErr)
	}
	return nil
}
// createAcceptOutgoingPacketRule builds (without installing) a rule that
// accepts packets whose output interface is tunname.
func createAcceptOutgoingPacketRule(table *nftables.Table, chain *nftables.Chain, tunname string) *nftables.Rule {
	exprs := []expr.Any{
		// output interface name == tunname
		&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
		&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte(tunname)},
		&expr.Counter{},
		&expr.Verdict{Kind: expr.VerdictAccept},
	}
	return &nftables.Rule{Table: table, Chain: chain, Exprs: exprs}
}
// addAcceptOutgoingPacketRule appends to chain the rule that accepts packets
// leaving through tunname, and commits it.
func addAcceptOutgoingPacketRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error {
	_ = conn.AddRule(createAcceptOutgoingPacketRule(table, chain, tunname))
	err := conn.Flush()
	if err != nil {
		return fmt.Errorf("flush add rule: %w", err)
	}
	return nil
}
// createAcceptIncomingPacketRule builds (without installing) a rule that
// accepts packets whose input interface is tunname.
func createAcceptIncomingPacketRule(table *nftables.Table, chain *nftables.Chain, tunname string) *nftables.Rule {
	exprs := []expr.Any{
		// input interface name == tunname
		&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
		&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte(tunname)},
		&expr.Counter{},
		&expr.Verdict{Kind: expr.VerdictAccept},
	}
	return &nftables.Rule{Table: table, Chain: chain, Exprs: exprs}
}
// addAcceptIncomingPacketRule appends to chain the rule that accepts packets
// arriving on tunname, and commits it.
func addAcceptIncomingPacketRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error {
	_ = conn.AddRule(createAcceptIncomingPacketRule(table, chain, tunname))
	err := conn.Flush()
	if err != nil {
		return fmt.Errorf("flush add rule: %w", err)
	}
	return nil
}
// AddBase installs the basic processing rules for tunname: always the IPv4
// set, and the IPv6 set as well when HasIPV6 reports true.
func (n *nftablesRunner) AddBase(tunname string) error {
	err := n.addBase4(tunname)
	if err != nil {
		return fmt.Errorf("add base v4: %w", err)
	}
	if !n.HasIPV6() {
		return nil
	}
	if err = n.addBase6(tunname); err != nil {
		return fmt.Errorf("add base v6: %w", err)
	}
	return nil
}
// addBase4 adds the basic IPv4 processing rules for tunname.
//
// In ts-input, in order: return for the ChromeOS VM range, drop for the CGNAT
// range (both only for packets not arriving on tunname), then accept packets
// arriving on tunname. In ts-forward, in order: set the subnet route mark,
// accept marked packets, drop CGNAT-sourced packets leaving through tunname,
// then accept packets leaving through tunname.
//
// Rules are committed one at a time, so rules added before a failing step
// remain installed.
func (n *nftablesRunner) addBase4(tunname string) error {
	conn := n.conn

	inputChain, err := getChainFromTable(conn, n.nft4.Filter, chainNameInput)
	if err != nil {
		// %w keeps the underlying error inspectable, as the other wraps here do.
		return fmt.Errorf("get input chain v4: %w", err)
	}
	if err = addReturnChromeOSVMRangeRule(conn, n.nft4.Filter, inputChain, tunname); err != nil {
		return fmt.Errorf("add return chromeos vm range rule v4: %w", err)
	}
	if err = addDropCGNATRangeRule(conn, n.nft4.Filter, inputChain, tunname); err != nil {
		return fmt.Errorf("add drop cgnat range rule v4: %w", err)
	}
	if err = addAcceptIncomingPacketRule(conn, n.nft4.Filter, inputChain, tunname); err != nil {
		return fmt.Errorf("add accept incoming packet rule v4: %w", err)
	}

	forwardChain, err := getChainFromTable(conn, n.nft4.Filter, chainNameForward)
	if err != nil {
		return fmt.Errorf("get forward chain v4: %w", err)
	}
	if err = addSetSubnetRouteMarkRule(conn, n.nft4.Filter, forwardChain, tunname); err != nil {
		return fmt.Errorf("add set subnet route mark rule v4: %w", err)
	}
	if err = addMatchSubnetRouteMarkRule(conn, n.nft4.Filter, forwardChain, Accept); err != nil {
		return fmt.Errorf("add match subnet route mark rule v4: %w", err)
	}
	if err = addDropOutgoingPacketFromCGNATRangeRuleWithTunname(conn, n.nft4.Filter, forwardChain, tunname); err != nil {
		return fmt.Errorf("add drop outgoing packet from cgnat range rule v4: %w", err)
	}
	if err = addAcceptOutgoingPacketRule(conn, n.nft4.Filter, forwardChain, tunname); err != nil {
		return fmt.Errorf("add accept outgoing packet rule v4: %w", err)
	}

	if err = conn.Flush(); err != nil {
		return fmt.Errorf("flush base v4: %w", err)
	}
	return nil
}
// addBase6 adds the basic IPv6 processing rules for tunname.
//
// In ts-input: accept packets arriving on tunname. In ts-forward, in order:
// set the subnet route mark, accept marked packets, then accept packets
// leaving through tunname.
//
// Rules are committed one at a time, so rules added before a failing step
// remain installed.
func (n *nftablesRunner) addBase6(tunname string) error {
	conn := n.conn

	inputChain, err := getChainFromTable(conn, n.nft6.Filter, chainNameInput)
	if err != nil {
		// This is the v6 table (the message used to say v4); %w keeps the
		// underlying error inspectable.
		return fmt.Errorf("get input chain v6: %w", err)
	}
	if err = addAcceptIncomingPacketRule(conn, n.nft6.Filter, inputChain, tunname); err != nil {
		return fmt.Errorf("add accept incoming packet rule v6: %w", err)
	}

	forwardChain, err := getChainFromTable(conn, n.nft6.Filter, chainNameForward)
	if err != nil {
		return fmt.Errorf("get forward chain v6: %w", err)
	}
	if err = addSetSubnetRouteMarkRule(conn, n.nft6.Filter, forwardChain, tunname); err != nil {
		return fmt.Errorf("add set subnet route mark rule v6: %w", err)
	}
	if err = addMatchSubnetRouteMarkRule(conn, n.nft6.Filter, forwardChain, Accept); err != nil {
		return fmt.Errorf("add match subnet route mark rule v6: %w", err)
	}
	if err = addAcceptOutgoingPacketRule(conn, n.nft6.Filter, forwardChain, tunname); err != nil {
		return fmt.Errorf("add accept outgoing packet rule v6: %w", err)
	}

	if err = conn.Flush(); err != nil {
		return fmt.Errorf("flush base v6: %w", err)
	}
	return nil
}
// DelBase empties, but does not remove, custom Tailscale chains from
// netfilter via nftables: ts-input and ts-forward in every filter table and
// ts-postrouting in every NAT table. All flushes are committed together at the
// end; a failed chain lookup aborts before anything is committed.
func (n *nftablesRunner) DelBase() error {
	conn := n.conn

	for _, table := range n.getTables() {
		inputChain, err := getChainFromTable(conn, table.Filter, chainNameInput)
		if err != nil {
			return fmt.Errorf("get input chain: %w", err)
		}
		conn.FlushChain(inputChain)

		forwardChain, err := getChainFromTable(conn, table.Filter, chainNameForward)
		if err != nil {
			return fmt.Errorf("get forward chain: %w", err)
		}
		conn.FlushChain(forwardChain)
	}

	for _, table := range n.getNATTables() {
		postrouteChain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting)
		if err != nil {
			// The loop covers every NAT family, so no "v4" label here.
			return fmt.Errorf("get postrouting chain: %w", err)
		}
		conn.FlushChain(postrouteChain)
	}

	return conn.Flush()
}
// createMatchSubnetRouteMarkRule builds (without installing) a rule that
// matches packets whose mark, ANDed with the Tailscale fwmark mask, equals the
// subnet route mark. Matching packets are masqueraded when action is Masq and
// accepted for any other action. The returned error is always nil.
func createMatchSubnetRouteMarkRule(table *nftables.Table, chain *nftables.Chain, action MatchDecision) (*nftables.Rule, error) {
	fwmarkMask := getTailscaleFwmarkMask()
	routeMark := getTailscaleSubnetRouteMark()

	var terminal expr.Any
	switch action {
	case Masq:
		terminal = &expr.Masq{}
	default:
		terminal = &expr.Verdict{Kind: expr.VerdictAccept}
	}

	return &nftables.Rule{
		Table: table,
		Chain: chain,
		Exprs: []expr.Any{
			// (mark & fwmarkMask) == routeMark
			&expr.Meta{Key: expr.MetaKeyMARK, Register: 1},
			&expr.Bitwise{
				SourceRegister: 1,
				DestRegister:   1,
				Len:            4,
				Mask:           fwmarkMask,
				Xor:            []byte{0x00, 0x00, 0x00, 0x00},
			},
			&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: routeMark},
			&expr.Counter{},
			terminal,
		},
	}, nil
}
// addMatchSubnetRouteMarkRule appends to chain the rule that matches packets
// carrying the subnet route mark and applies action to them, then commits it.
func addMatchSubnetRouteMarkRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, action MatchDecision) error {
	matchRule, err := createMatchSubnetRouteMarkRule(table, chain, action)
	if err != nil {
		return fmt.Errorf("create match subnet route mark rule: %w", err)
	}
	_ = conn.AddRule(matchRule)
	if flushErr := conn.Flush(); flushErr != nil {
		return fmt.Errorf("flush add rule: %w", flushErr)
	}
	return nil
}
// AddSNATRule appends, to the ts-postrouting chain of every NAT table, a rule
// that masquerades packets carrying the subnet route mark.
func (n *nftablesRunner) AddSNATRule() error {
	conn := n.conn

	for _, table := range n.getNATTables() {
		postrouting, err := getChainFromTable(conn, table.Nat, chainNamePostrouting)
		if err != nil {
			return fmt.Errorf("get postrouting chain v4: %w", err)
		}
		err = addMatchSubnetRouteMarkRule(conn, table.Nat, postrouting, Masq)
		if err != nil {
			return fmt.Errorf("add match subnet route mark rule v4: %w", err)
		}
	}

	if flushErr := conn.Flush(); flushErr != nil {
		return fmt.Errorf("flush add SNAT rule: %w", flushErr)
	}
	return nil
}
// DelSNATRule removes the netfilter rule to SNAT traffic destined for
// local subnets from the ts-postrouting chain of every NAT table. A table in
// which the rule is not installed is skipped without error.
func (n *nftablesRunner) DelSNATRule() error {
	conn := n.conn

	for _, table := range n.getNATTables() {
		chain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting)
		if err != nil {
			return fmt.Errorf("get postrouting chain v4: %w", err)
		}

		// Build the rule with the same constructor AddSNATRule uses, so the
		// expressions handed to findRule are field-for-field what was
		// installed. The previous hand-written copy left Bitwise.Xor unset,
		// which does not deep-equal the zeroed Xor of the installed rule.
		rule, err := createMatchSubnetRouteMarkRule(table.Nat, chain, Masq)
		if err != nil {
			return fmt.Errorf("create match subnet route mark rule: %w", err)
		}

		SNATRule, err := findRule(conn, rule)
		if err != nil {
			return fmt.Errorf("find SNAT rule v4: %w", err)
		}
		if SNATRule != nil {
			_ = conn.DelRule(SNATRule)
		}
	}

	if err := conn.Flush(); err != nil {
		return fmt.Errorf("flush del SNAT rule: %w", err)
	}
	return nil
}
// cleanupChain removes a jump rule from hookChainName to tsChainName, and then
// the entire chain tsChainName. Errors are logged, but attempts to remove both
// the jump rule and chain continue even if one errors. A chain that simply
// does not exist is not treated as an error.
func cleanupChain(logf logger.Logf, conn *nftables.Conn, table *nftables.Table, hookChainName, tsChainName string) {
	// remove the jump first, before removing the jump destination.
	defaultChain, err := getChainFromTable(conn, table, hookChainName)
	switch {
	case err == nil:
		// delete hook in convention chain
		if err := delHookRule(conn, table, defaultChain, tsChainName); err != nil {
			logf("cleanup: delete hook rule from %s to %s: %s", hookChainName, tsChainName, err)
		}
	case !errors.Is(err, errorChainNotFound{table.Name, hookChainName}):
		// The lookup itself failed, so there is no chain to remove the jump
		// from; do not call delHookRule with a nil chain.
		logf("cleanup: did not find default chain: %s", err)
	}

	tsChain, err := getChainFromTable(conn, table, tsChainName)
	if err != nil {
		if !errors.Is(err, errorChainNotFound{table.Name, tsChainName}) {
			logf("cleanup: did not find ts-chain: %s", err)
		}
		return
	}

	// flush and delete ts-chain
	conn.FlushChain(tsChain)
	conn.DelChain(tsChain)
	if err := conn.Flush(); err != nil {
		// Only report an actual failure.
		logf("cleanup: delete and flush chain %s: %s", tsChainName, err)
	}
}
// NfTablesCleanUp removes all Tailscale added nftables rules.
// Any errors that occur are logged to the provided logf. If the nftables
// connection cannot be opened or the tables cannot be listed, nothing is
// cleaned up.
func NfTablesCleanUp(logf logger.Logf) {
	conn, err := nftables.New()
	if err != nil {
		logf("cleanup: nftables connection: %s", err)
		// Without a connection there is nothing that can be done.
		return
	}

	tables, err := conn.ListTables() // both v4 and v6
	if err != nil {
		logf("cleanup: list tables: %s", err)
		return
	}

	for _, table := range tables {
		switch table.Name {
		case "ts-filter", "ts-nat":
			// These table names were used briefly in 1.48.0.
			conn.DelTable(table)
			if err := conn.Flush(); err != nil {
				logf("cleanup: flush delete table %s: %s", table.Name, err)
			}
		case "filter":
			cleanupChain(logf, conn, table, "INPUT", chainNameInput)
			cleanupChain(logf, conn, table, "FORWARD", chainNameForward)
		case "nat":
			cleanupChain(logf, conn, table, "POSTROUTING", chainNamePostrouting)
		}
	}
}