mirror of
https://github.com/tailscale/tailscale.git
synced 2024-12-04 23:45:34 +00:00
c3a8e63100
We were previously using the netlink API to see if there are chains/rules that already exist. This works fine in environments where there is either full nftable support or no support at all. However, we have identified certain environments which have partial nftable support and the only feasible way of detecting such an environment is to try to create some of the chains that we need. This adds a check to create a dummy postrouting chain which is immediately deleted. The goal of the check is to ensure we are able to use nftables and that it won't error out later. This check is only done in the path where we detected that the system has no preexisting nftable rules. Updates #5621 Updates #8555 Updates #8762 Signed-off-by: Maisem Ali <maisem@tailscale.com>
1613 lines
47 KiB
Go
1613 lines
47 KiB
Go
// Copyright (c) Tailscale Inc & AUTHORS
|
|
// SPDX-License-Identifier: BSD-3-Clause
|
|
|
|
//go:build linux
|
|
|
|
package linuxfw
|
|
|
|
import (
|
|
"encoding/binary"
|
|
"encoding/hex"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"net/netip"
|
|
"reflect"
|
|
"strings"
|
|
|
|
"github.com/google/nftables"
|
|
"github.com/google/nftables/expr"
|
|
"golang.org/x/sys/unix"
|
|
"tailscale.com/net/tsaddr"
|
|
"tailscale.com/types/logger"
|
|
"tailscale.com/types/ptr"
|
|
)
|
|
|
|
const (
	// Names of the Tailscale-managed chains. They all carry the "ts-"
	// prefix that isTSChain looks for.
	chainNameForward     = "ts-forward"
	chainNameInput       = "ts-input"
	chainNamePostrouting = "ts-postrouting"

	// chainTypeRegular is the chain type of an nftables chain that is not
	// attached to a hook.
	chainTypeRegular = ""
)
|
|
|
|
type chainInfo struct {
|
|
table *nftables.Table
|
|
name string
|
|
chainType nftables.ChainType
|
|
chainHook *nftables.ChainHook
|
|
chainPriority *nftables.ChainPriority
|
|
chainPolicy *nftables.ChainPolicy
|
|
}
|
|
|
|
type nftable struct {
|
|
Proto nftables.TableFamily
|
|
Filter *nftables.Table
|
|
Nat *nftables.Table
|
|
}
|
|
|
|
// nftablesRunner implements a netfilterRunner using the netlink based nftables
|
|
// library. As nftables allows for arbitrary tables and chains, there is a need
|
|
// to follow conventions in order to integrate well with a surrounding
|
|
// ecosystem. The rules installed by nftablesRunner have the following
|
|
// properties:
|
|
// - Install rules that intend to take precedence over rules installed by
|
|
// other software. Tailscale provides packet filtering for tailnet traffic
|
|
// inside the daemon based on the tailnet ACL rules.
|
|
// - As nftables "accept" is not final, rules from high priority tables (low
|
|
// numbers) will fall through to lower priority tables (high numbers). In
|
|
// order to effectively be 'final', we install "jump" rules into conventional
|
|
// tables and chains that will reach an accept verdict inside those tables.
|
|
// - The table and chain conventions followed here are those used by
|
|
// `iptables-nft` and `ufw`, so that those tools co-exist and do not
|
|
// negatively affect Tailscale function.
|
|
type nftablesRunner struct {
|
|
conn *nftables.Conn
|
|
nft4 *nftable
|
|
nft6 *nftable
|
|
|
|
v6Available bool
|
|
v6NATAvailable bool
|
|
}
|
|
|
|
func (n *nftablesRunner) ensurePreroutingChain(dst netip.Addr) (*nftables.Table, *nftables.Chain, error) {
|
|
polAccept := nftables.ChainPolicyAccept
|
|
table := n.getNFTByAddr(dst)
|
|
nat, err := createTableIfNotExist(n.conn, table.Proto, "nat")
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("error ensuring nat table: %w", err)
|
|
}
|
|
|
|
// ensure prerouting chain exists
|
|
preroutingCh, err := getOrCreateChain(n.conn, chainInfo{
|
|
table: nat,
|
|
name: "PREROUTING",
|
|
chainType: nftables.ChainTypeNAT,
|
|
chainHook: nftables.ChainHookPrerouting,
|
|
chainPriority: nftables.ChainPriorityNATDest,
|
|
chainPolicy: &polAccept,
|
|
})
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("error ensuring prerouting chain: %w", err)
|
|
}
|
|
return nat, preroutingCh, nil
|
|
}
|
|
|
|
func (n *nftablesRunner) AddDNATRule(origDst netip.Addr, dst netip.Addr) error {
|
|
nat, preroutingCh, err := n.ensurePreroutingChain(dst)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
var daddrOffset, fam, dadderLen uint32
|
|
if origDst.Is4() {
|
|
daddrOffset = 16
|
|
dadderLen = 4
|
|
fam = unix.NFPROTO_IPV4
|
|
} else {
|
|
daddrOffset = 24
|
|
dadderLen = 16
|
|
fam = unix.NFPROTO_IPV6
|
|
}
|
|
|
|
dnatRule := &nftables.Rule{
|
|
Table: nat,
|
|
Chain: preroutingCh,
|
|
Exprs: []expr.Any{
|
|
&expr.Payload{
|
|
DestRegister: 1,
|
|
Base: expr.PayloadBaseNetworkHeader,
|
|
Offset: daddrOffset,
|
|
Len: dadderLen,
|
|
},
|
|
&expr.Cmp{
|
|
Op: expr.CmpOpEq,
|
|
Register: 1,
|
|
Data: origDst.AsSlice(),
|
|
},
|
|
&expr.Immediate{
|
|
Register: 1,
|
|
Data: dst.AsSlice(),
|
|
},
|
|
&expr.NAT{
|
|
Type: expr.NATTypeDestNAT,
|
|
Family: fam,
|
|
RegAddrMin: 1,
|
|
},
|
|
},
|
|
}
|
|
n.conn.InsertRule(dnatRule)
|
|
return n.conn.Flush()
|
|
}
|
|
|
|
func (n *nftablesRunner) DNATNonTailscaleTraffic(tunname string, dst netip.Addr) error {
|
|
nat, preroutingCh, err := n.ensurePreroutingChain(dst)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
var famConst uint32
|
|
if dst.Is4() {
|
|
famConst = unix.NFPROTO_IPV4
|
|
} else {
|
|
famConst = unix.NFPROTO_IPV6
|
|
}
|
|
|
|
dnatRule := &nftables.Rule{
|
|
Table: nat,
|
|
Chain: preroutingCh,
|
|
Exprs: []expr.Any{
|
|
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
|
&expr.Cmp{
|
|
Op: expr.CmpOpNeq,
|
|
Register: 1,
|
|
Data: []byte(tunname),
|
|
},
|
|
&expr.Immediate{
|
|
Register: 1,
|
|
Data: dst.AsSlice(),
|
|
},
|
|
&expr.NAT{
|
|
Type: expr.NATTypeDestNAT,
|
|
Family: famConst,
|
|
RegAddrMin: 1,
|
|
},
|
|
},
|
|
}
|
|
n.conn.AddRule(dnatRule)
|
|
return n.conn.Flush()
|
|
}
|
|
|
|
func (n *nftablesRunner) AddSNATRuleForDst(src, dst netip.Addr) error {
|
|
polAccept := nftables.ChainPolicyAccept
|
|
table := n.getNFTByAddr(dst)
|
|
nat, err := createTableIfNotExist(n.conn, table.Proto, "nat")
|
|
if err != nil {
|
|
return fmt.Errorf("error ensuring nat table exists: %w", err)
|
|
}
|
|
|
|
// ensure postrouting chain exists
|
|
postRoutingCh, err := getOrCreateChain(n.conn, chainInfo{
|
|
table: nat,
|
|
name: "POSTROUTING",
|
|
chainType: nftables.ChainTypeNAT,
|
|
chainHook: nftables.ChainHookPostrouting,
|
|
chainPriority: nftables.ChainPriorityNATSource,
|
|
chainPolicy: &polAccept,
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("error ensuring postrouting chain: %w", err)
|
|
}
|
|
var daddrOffset, fam, daddrLen uint32
|
|
if dst.Is4() {
|
|
daddrOffset = 16
|
|
daddrLen = 4
|
|
fam = unix.NFPROTO_IPV4
|
|
} else {
|
|
daddrOffset = 24
|
|
daddrLen = 16
|
|
fam = unix.NFPROTO_IPV6
|
|
}
|
|
|
|
snatRule := &nftables.Rule{
|
|
Table: nat,
|
|
Chain: postRoutingCh,
|
|
Exprs: []expr.Any{
|
|
&expr.Payload{
|
|
DestRegister: 1,
|
|
Base: expr.PayloadBaseNetworkHeader,
|
|
Offset: daddrOffset,
|
|
Len: daddrLen,
|
|
},
|
|
&expr.Cmp{
|
|
Op: expr.CmpOpEq,
|
|
Register: 1,
|
|
Data: dst.AsSlice(),
|
|
},
|
|
&expr.Immediate{
|
|
Register: 1,
|
|
Data: src.AsSlice(),
|
|
},
|
|
&expr.NAT{
|
|
Type: expr.NATTypeSourceNAT,
|
|
Family: fam,
|
|
RegAddrMin: 1,
|
|
},
|
|
},
|
|
}
|
|
n.conn.AddRule(snatRule)
|
|
return n.conn.Flush()
|
|
}
|
|
|
|
func (n *nftablesRunner) ClampMSSToPMTU(tun string, addr netip.Addr) error {
|
|
polAccept := nftables.ChainPolicyAccept
|
|
table := n.getNFTByAddr(addr)
|
|
filterTable, err := createTableIfNotExist(n.conn, table.Proto, "filter")
|
|
if err != nil {
|
|
return fmt.Errorf("error ensuring filter table: %w", err)
|
|
}
|
|
|
|
// ensure forwarding chain exists
|
|
fwChain, err := getOrCreateChain(n.conn, chainInfo{
|
|
table: filterTable,
|
|
name: "FORWARD",
|
|
chainType: nftables.ChainTypeFilter,
|
|
chainHook: nftables.ChainHookForward,
|
|
chainPriority: nftables.ChainPriorityFilter,
|
|
chainPolicy: &polAccept,
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("error ensuring forward chain: %w", err)
|
|
}
|
|
|
|
clampRule := &nftables.Rule{
|
|
Table: filterTable,
|
|
Chain: fwChain,
|
|
Exprs: []expr.Any{
|
|
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
|
&expr.Cmp{
|
|
Op: expr.CmpOpEq,
|
|
Register: 1,
|
|
Data: []byte(tun),
|
|
},
|
|
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
|
&expr.Cmp{
|
|
Op: expr.CmpOpEq,
|
|
Register: 1,
|
|
Data: []byte{unix.IPPROTO_TCP},
|
|
},
|
|
&expr.Payload{
|
|
DestRegister: 1,
|
|
Base: expr.PayloadBaseTransportHeader,
|
|
Offset: 13,
|
|
Len: 1,
|
|
},
|
|
&expr.Bitwise{
|
|
DestRegister: 1,
|
|
SourceRegister: 1,
|
|
Len: 1,
|
|
Mask: []byte{0x02},
|
|
Xor: []byte{0x00},
|
|
},
|
|
&expr.Cmp{
|
|
Op: expr.CmpOpNeq,
|
|
Register: 1,
|
|
Data: []byte{0x00},
|
|
},
|
|
&expr.Rt{
|
|
Register: 1,
|
|
Key: expr.RtTCPMSS,
|
|
},
|
|
&expr.Byteorder{
|
|
DestRegister: 1,
|
|
SourceRegister: 1,
|
|
Op: expr.ByteorderHton,
|
|
Len: 2,
|
|
Size: 2,
|
|
},
|
|
&expr.Exthdr{
|
|
SourceRegister: 1,
|
|
Type: 2,
|
|
Offset: 2,
|
|
Len: 2,
|
|
Op: expr.ExthdrOpTcpopt,
|
|
},
|
|
},
|
|
}
|
|
n.conn.AddRule(clampRule)
|
|
return n.conn.Flush()
|
|
}
|
|
|
|
// deleteTableIfExists deletes a nftables table via connection c if it exists
|
|
// within the given family.
|
|
func deleteTableIfExists(c *nftables.Conn, family nftables.TableFamily, name string) error {
|
|
t, err := getTableIfExists(c, family, name)
|
|
if err != nil {
|
|
return fmt.Errorf("get table: %w", err)
|
|
}
|
|
if t == nil {
|
|
// Table does not exist, so nothing to delete.
|
|
return nil
|
|
}
|
|
c.DelTable(t)
|
|
if err := c.Flush(); err != nil {
|
|
if t, err = getTableIfExists(c, family, name); t == nil && err == nil {
|
|
// Check if the table still exists. If it does not, then the error
|
|
// is due to the table not existing, so we can ignore it. Maybe a
|
|
// concurrent process deleted the table.
|
|
return nil
|
|
}
|
|
return fmt.Errorf("del table: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// getTableIfExists returns the table with the given name from the given family
|
|
// if it exists. If none match, it returns (nil, nil).
|
|
func getTableIfExists(c *nftables.Conn, family nftables.TableFamily, name string) (*nftables.Table, error) {
|
|
tables, err := c.ListTables()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get tables: %w", err)
|
|
}
|
|
for _, table := range tables {
|
|
if table.Name == name && table.Family == family {
|
|
return table, nil
|
|
}
|
|
}
|
|
return nil, nil
|
|
}
|
|
|
|
// createTableIfNotExist creates a nftables table via connection c if it does
|
|
// not exist within the given family.
|
|
func createTableIfNotExist(c *nftables.Conn, family nftables.TableFamily, name string) (*nftables.Table, error) {
|
|
if t, err := getTableIfExists(c, family, name); err != nil {
|
|
return nil, fmt.Errorf("get table: %w", err)
|
|
} else if t != nil {
|
|
return t, nil
|
|
}
|
|
t := c.AddTable(&nftables.Table{
|
|
Family: family,
|
|
Name: name,
|
|
})
|
|
if err := c.Flush(); err != nil {
|
|
return nil, fmt.Errorf("add table: %w", err)
|
|
}
|
|
return t, nil
|
|
}
|
|
|
|
// errorChainNotFound reports that a chain is missing from a table.
//
// Values are built and compared with positional literals of the form
// errorChainNotFound{tableName, chainName}, so the table name must be the
// first field; with the chain name first, the message swapped the two.
type errorChainNotFound struct {
	tableName string
	chainName string
}

// Error implements the error interface.
func (e errorChainNotFound) Error() string {
	return fmt.Sprintf("chain %s not found in table %s", e.chainName, e.tableName)
}
|
|
|
|
// getChainFromTable returns the chain with the given name from the given table.
|
|
// Note that a chain name is unique within a table.
|
|
func getChainFromTable(c *nftables.Conn, table *nftables.Table, name string) (*nftables.Chain, error) {
|
|
chains, err := c.ListChainsOfTableFamily(table.Family)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("list chains: %w", err)
|
|
}
|
|
|
|
for _, chain := range chains {
|
|
// Table family is already checked so table name is unique
|
|
if chain.Table.Name == table.Name && chain.Name == name {
|
|
return chain, nil
|
|
}
|
|
}
|
|
|
|
return nil, errorChainNotFound{table.Name, name}
|
|
}
|
|
|
|
// isTSChain reports whether name carries the "ts-" prefix that marks a
// Tailscale-managed chain.
func isTSChain(name string) bool {
	const tsChainPrefix = "ts-"
	return strings.HasPrefix(name, tsChainPrefix)
}
|
|
|
|
// createChainIfNotExist creates a chain with the given name in the given table
|
|
// if it does not exist.
|
|
func createChainIfNotExist(c *nftables.Conn, cinfo chainInfo) error {
|
|
_, err := getOrCreateChain(c, cinfo)
|
|
return err
|
|
}
|
|
|
|
func getOrCreateChain(c *nftables.Conn, cinfo chainInfo) (*nftables.Chain, error) {
|
|
chain, err := getChainFromTable(c, cinfo.table, cinfo.name)
|
|
if err != nil && !errors.Is(err, errorChainNotFound{cinfo.table.Name, cinfo.name}) {
|
|
return nil, fmt.Errorf("get chain: %w", err)
|
|
} else if err == nil {
|
|
// The chain already exists. If it is a TS chain, check the
|
|
// type/hook/priority, but for "conventional chains" assume they're what
|
|
// we expect (in case iptables-nft/ufw make minor behavior changes in
|
|
// the future).
|
|
if isTSChain(chain.Name) && (chain.Type != cinfo.chainType || chain.Hooknum != cinfo.chainHook || chain.Priority != cinfo.chainPriority) {
|
|
return nil, fmt.Errorf("chain %s already exists with different type/hook/priority", cinfo.name)
|
|
}
|
|
return chain, nil
|
|
}
|
|
|
|
_ = c.AddChain(&nftables.Chain{
|
|
Name: cinfo.name,
|
|
Table: cinfo.table,
|
|
Type: cinfo.chainType,
|
|
Hooknum: cinfo.chainHook,
|
|
Priority: cinfo.chainPriority,
|
|
Policy: cinfo.chainPolicy,
|
|
})
|
|
|
|
if err := c.Flush(); err != nil {
|
|
return nil, fmt.Errorf("add chain: %w", err)
|
|
}
|
|
|
|
return chain, nil
|
|
}
|
|
|
|
// NetfilterRunner abstracts helpers to run netfilter commands. It is
|
|
// implemented by linuxfw.IPTablesRunner and linuxfw.NfTablesRunner.
|
|
type NetfilterRunner interface {
|
|
// AddLoopbackRule adds a rule to permit loopback traffic to addr. This rule
|
|
// is added only if it does not already exist.
|
|
AddLoopbackRule(addr netip.Addr) error
|
|
|
|
// DelLoopbackRule removes the rule added by AddLoopbackRule.
|
|
DelLoopbackRule(addr netip.Addr) error
|
|
|
|
// AddHooks adds rules to conventional chains like "FORWARD", "INPUT" and
|
|
// "POSTROUTING" to jump from those chains to tailscale chains.
|
|
AddHooks() error
|
|
|
|
// DelHooks deletes rules added by AddHooks.
|
|
DelHooks(logf logger.Logf) error
|
|
|
|
// AddChains creates custom Tailscale chains.
|
|
AddChains() error
|
|
|
|
// DelChains removes chains added by AddChains.
|
|
DelChains() error
|
|
|
|
// AddBase adds rules reused by different other rules.
|
|
AddBase(tunname string) error
|
|
|
|
// DelBase removes rules added by AddBase.
|
|
DelBase() error
|
|
|
|
// AddSNATRule adds the netfilter rule to SNAT incoming traffic over
|
|
// the Tailscale interface destined for local subnets. An error is
|
|
// returned if the rule already exists.
|
|
AddSNATRule() error
|
|
|
|
// DelSNATRule removes the rule added by AddSNATRule.
|
|
DelSNATRule() error
|
|
|
|
// HasIPV6 reports true if the system supports IPv6.
|
|
HasIPV6() bool
|
|
|
|
// HasIPV6NAT reports true if the system supports IPv6 NAT.
|
|
HasIPV6NAT() bool
|
|
|
|
// AddDNATRule adds a rule to the nat/PREROUTING chain to DNAT traffic
|
|
// destined for the given original destination to the given new destination.
|
|
// This is used to forward all traffic destined for the Tailscale interface
|
|
// to the provided destination, as used in the Kubernetes ingress proxies.
|
|
AddDNATRule(origDst, dst netip.Addr) error
|
|
|
|
// AddSNATRuleForDst adds a rule to the nat/POSTROUTING chain to SNAT
|
|
// traffic destined for dst to src.
|
|
// This is used to forward traffic destined for the local machine over
|
|
// the Tailscale interface, as used in the Kubernetes egress proxies.
|
|
AddSNATRuleForDst(src, dst netip.Addr) error
|
|
|
|
// DNATNonTailscaleTraffic adds a rule to the nat/PREROUTING chain to DNAT
|
|
// all traffic inbound from any interface except exemptInterface to dst.
|
|
// This is used to forward traffic destined for the local machine over
|
|
// the Tailscale interface, as used in the Kubernetes egress proxies.//
|
|
DNATNonTailscaleTraffic(exemptInterface string, dst netip.Addr) error
|
|
|
|
// ClampMSSToPMTU adds a rule to the mangle/FORWARD chain to clamp MSS for
|
|
// traffic destined for the provided tun interface.
|
|
ClampMSSToPMTU(tun string, addr netip.Addr) error
|
|
}
|
|
|
|
// New creates a NetfilterRunner using either nftables or iptables.
|
|
// As nftables is still experimental, iptables will be used unless TS_DEBUG_USE_NETLINK_NFTABLES is set.
|
|
func New(logf logger.Logf) (NetfilterRunner, error) {
|
|
mode := detectFirewallMode(logf)
|
|
switch mode {
|
|
case FirewallModeIPTables:
|
|
return newIPTablesRunner(logf)
|
|
case FirewallModeNfTables:
|
|
return newNfTablesRunner(logf)
|
|
default:
|
|
return nil, fmt.Errorf("unknown firewall mode %v", mode)
|
|
}
|
|
}
|
|
|
|
// newNfTablesRunner creates a new nftablesRunner without guaranteeing
|
|
// the existence of the tables and chains.
|
|
func newNfTablesRunner(logf logger.Logf) (*nftablesRunner, error) {
|
|
conn, err := nftables.New()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("nftables connection: %w", err)
|
|
}
|
|
nft4 := &nftable{Proto: nftables.TableFamilyIPv4}
|
|
|
|
v6err := checkIPv6(logf)
|
|
if v6err != nil {
|
|
logf("disabling tunneled IPv6 due to system IPv6 config: %v", v6err)
|
|
}
|
|
supportsV6 := v6err == nil
|
|
supportsV6NAT := supportsV6 && checkSupportsV6NAT()
|
|
|
|
var nft6 *nftable
|
|
if supportsV6 {
|
|
logf("v6nat availability: %v", supportsV6NAT)
|
|
nft6 = &nftable{Proto: nftables.TableFamilyIPv6}
|
|
}
|
|
|
|
// TODO(KevinLiang10): convert iptables rule to nftable rules if they exist in the iptables
|
|
|
|
return &nftablesRunner{
|
|
conn: conn,
|
|
nft4: nft4,
|
|
nft6: nft6,
|
|
v6Available: supportsV6,
|
|
v6NATAvailable: supportsV6NAT,
|
|
}, nil
|
|
}
|
|
|
|
// newLoadSaddrExpr creates a new nftables expression that loads the source
|
|
// address of the packet into the given register.
|
|
func newLoadSaddrExpr(proto nftables.TableFamily, destReg uint32) (expr.Any, error) {
|
|
switch proto {
|
|
case nftables.TableFamilyIPv4:
|
|
return &expr.Payload{
|
|
DestRegister: destReg,
|
|
Base: expr.PayloadBaseNetworkHeader,
|
|
Offset: 12,
|
|
Len: 4,
|
|
}, nil
|
|
case nftables.TableFamilyIPv6:
|
|
return &expr.Payload{
|
|
DestRegister: destReg,
|
|
Base: expr.PayloadBaseNetworkHeader,
|
|
Offset: 8,
|
|
Len: 16,
|
|
}, nil
|
|
default:
|
|
return nil, fmt.Errorf("table family %v is neither IPv4 nor IPv6", proto)
|
|
}
|
|
}
|
|
|
|
// HasIPV6 reports true if the system supports IPv6.
|
|
func (n *nftablesRunner) HasIPV6() bool {
|
|
return n.v6Available
|
|
}
|
|
|
|
// HasIPV6NAT returns true if the system supports IPv6 NAT.
|
|
func (n *nftablesRunner) HasIPV6NAT() bool {
|
|
return n.v6NATAvailable
|
|
}
|
|
|
|
// findRule iterates through the rules to find the rule with matching expressions.
|
|
func findRule(conn *nftables.Conn, rule *nftables.Rule) (*nftables.Rule, error) {
|
|
rules, err := conn.GetRules(rule.Table, rule.Chain)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get nftables rules: %w", err)
|
|
}
|
|
if len(rules) == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
ruleLoop:
|
|
for _, r := range rules {
|
|
if len(r.Exprs) != len(rule.Exprs) {
|
|
continue
|
|
}
|
|
|
|
for i, e := range r.Exprs {
|
|
// Skip counter expressions, as they will not match.
|
|
if _, ok := e.(*expr.Counter); ok {
|
|
continue
|
|
}
|
|
if !reflect.DeepEqual(e, rule.Exprs[i]) {
|
|
continue ruleLoop
|
|
}
|
|
}
|
|
return r, nil
|
|
}
|
|
|
|
return nil, nil
|
|
}
|
|
|
|
func createLoopbackRule(
|
|
proto nftables.TableFamily,
|
|
table *nftables.Table,
|
|
chain *nftables.Chain,
|
|
addr netip.Addr,
|
|
) (*nftables.Rule, error) {
|
|
saddrExpr, err := newLoadSaddrExpr(proto, 1)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("newLoadSaddrExpr: %w", err)
|
|
}
|
|
loopBackRule := &nftables.Rule{
|
|
Table: table,
|
|
Chain: chain,
|
|
Exprs: []expr.Any{
|
|
&expr.Meta{
|
|
Key: expr.MetaKeyIIFNAME,
|
|
Register: 1,
|
|
},
|
|
&expr.Cmp{
|
|
Op: expr.CmpOpEq,
|
|
Register: 1,
|
|
Data: []byte("lo"),
|
|
},
|
|
saddrExpr,
|
|
&expr.Cmp{
|
|
Op: expr.CmpOpEq,
|
|
Register: 1,
|
|
Data: addr.AsSlice(),
|
|
},
|
|
&expr.Counter{},
|
|
&expr.Verdict{
|
|
Kind: expr.VerdictAccept,
|
|
},
|
|
},
|
|
}
|
|
return loopBackRule, nil
|
|
}
|
|
|
|
// insertLoopbackRule inserts the TS loop back rule into
|
|
// the given chain as the first rule if it does not exist.
|
|
func insertLoopbackRule(
|
|
conn *nftables.Conn, proto nftables.TableFamily,
|
|
table *nftables.Table, chain *nftables.Chain, addr netip.Addr) error {
|
|
|
|
loopBackRule, err := createLoopbackRule(proto, table, chain, addr)
|
|
if err != nil {
|
|
return fmt.Errorf("create loopback rule: %w", err)
|
|
}
|
|
|
|
// If TestDial is set, we are running in test mode and we should not
|
|
// find rule because header will mismatch.
|
|
if conn.TestDial == nil {
|
|
// Check if the rule already exists.
|
|
rule, err := findRule(conn, loopBackRule)
|
|
if err != nil {
|
|
return fmt.Errorf("find rule: %w", err)
|
|
}
|
|
if rule != nil {
|
|
// Rule already exists, no need to insert.
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// This inserts the rule to the top of the chain
|
|
_ = conn.InsertRule(loopBackRule)
|
|
|
|
if err = conn.Flush(); err != nil {
|
|
return fmt.Errorf("insert rule: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// getNFTByAddr returns the nftables with correct IP family
|
|
// that we will be using for the given address.
|
|
func (n *nftablesRunner) getNFTByAddr(addr netip.Addr) *nftable {
|
|
if addr.Is6() {
|
|
return n.nft6
|
|
}
|
|
return n.nft4
|
|
}
|
|
|
|
// AddLoopbackRule adds an nftables rule to permit loopback traffic to
|
|
// a local Tailscale IP. This rule is added only if it does not already exist.
|
|
func (n *nftablesRunner) AddLoopbackRule(addr netip.Addr) error {
|
|
nf := n.getNFTByAddr(addr)
|
|
|
|
inputChain, err := getChainFromTable(n.conn, nf.Filter, chainNameInput)
|
|
if err != nil {
|
|
return fmt.Errorf("get input chain: %w", err)
|
|
}
|
|
|
|
if err := insertLoopbackRule(n.conn, nf.Proto, nf.Filter, inputChain, addr); err != nil {
|
|
return fmt.Errorf("add loopback rule: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// DelLoopbackRule removes the nftables rule permitting loopback
|
|
// traffic to a Tailscale IP.
|
|
func (n *nftablesRunner) DelLoopbackRule(addr netip.Addr) error {
|
|
nf := n.getNFTByAddr(addr)
|
|
|
|
inputChain, err := getChainFromTable(n.conn, nf.Filter, chainNameInput)
|
|
if err != nil {
|
|
return fmt.Errorf("get input chain: %w", err)
|
|
}
|
|
|
|
loopBackRule, err := createLoopbackRule(nf.Proto, nf.Filter, inputChain, addr)
|
|
if err != nil {
|
|
return fmt.Errorf("create loopback rule: %w", err)
|
|
}
|
|
|
|
existingLoopBackRule, err := findRule(n.conn, loopBackRule)
|
|
if err != nil {
|
|
return fmt.Errorf("find loop back rule: %w", err)
|
|
}
|
|
if existingLoopBackRule == nil {
|
|
// Rule does not exist, no need to delete.
|
|
return nil
|
|
}
|
|
|
|
if err := n.conn.DelRule(existingLoopBackRule); err != nil {
|
|
return fmt.Errorf("delete rule: %w", err)
|
|
}
|
|
|
|
return n.conn.Flush()
|
|
}
|
|
|
|
// getTables gets the available nftable in nftables runner.
|
|
func (n *nftablesRunner) getTables() []*nftable {
|
|
if n.v6Available {
|
|
return []*nftable{n.nft4, n.nft6}
|
|
}
|
|
return []*nftable{n.nft4}
|
|
}
|
|
|
|
// getNATTables gets the available nftable in nftables runner.
|
|
// If the system does not support IPv6 NAT, only the IPv4 nftable
|
|
// will be returned.
|
|
func (n *nftablesRunner) getNATTables() []*nftable {
|
|
if n.v6NATAvailable {
|
|
return n.getTables()
|
|
}
|
|
return []*nftable{n.nft4}
|
|
}
|
|
|
|
// AddChains creates custom Tailscale chains in netfilter via nftables
|
|
// if the ts-chain doesn't already exist.
|
|
func (n *nftablesRunner) AddChains() error {
|
|
polAccept := nftables.ChainPolicyAccept
|
|
for _, table := range n.getTables() {
|
|
// Create the filter table if it doesn't exist, this table name is the same
|
|
// as the name used by iptables-nft and ufw. We install rules into the
|
|
// same conventional table so that `accept` verdicts from our jump
|
|
// chains are conclusive.
|
|
filter, err := createTableIfNotExist(n.conn, table.Proto, "filter")
|
|
if err != nil {
|
|
return fmt.Errorf("create table: %w", err)
|
|
}
|
|
table.Filter = filter
|
|
// Adding the "conventional chains" that are used by iptables-nft and ufw.
|
|
if err = createChainIfNotExist(n.conn, chainInfo{filter, "FORWARD", nftables.ChainTypeFilter, nftables.ChainHookForward, nftables.ChainPriorityFilter, &polAccept}); err != nil {
|
|
return fmt.Errorf("create forward chain: %w", err)
|
|
}
|
|
if err = createChainIfNotExist(n.conn, chainInfo{filter, "INPUT", nftables.ChainTypeFilter, nftables.ChainHookInput, nftables.ChainPriorityFilter, &polAccept}); err != nil {
|
|
return fmt.Errorf("create input chain: %w", err)
|
|
}
|
|
// Adding the tailscale chains that contain our rules.
|
|
if err = createChainIfNotExist(n.conn, chainInfo{filter, chainNameForward, chainTypeRegular, nil, nil, nil}); err != nil {
|
|
return fmt.Errorf("create forward chain: %w", err)
|
|
}
|
|
if err = createChainIfNotExist(n.conn, chainInfo{filter, chainNameInput, chainTypeRegular, nil, nil, nil}); err != nil {
|
|
return fmt.Errorf("create input chain: %w", err)
|
|
}
|
|
}
|
|
|
|
for _, table := range n.getNATTables() {
|
|
// Create the nat table if it doesn't exist, this table name is the same
|
|
// as the name used by iptables-nft and ufw. We install rules into the
|
|
// same conventional table so that `accept` verdicts from our jump
|
|
// chains are conclusive.
|
|
nat, err := createTableIfNotExist(n.conn, table.Proto, "nat")
|
|
if err != nil {
|
|
return fmt.Errorf("create table: %w", err)
|
|
}
|
|
table.Nat = nat
|
|
// Adding the "conventional chains" that are used by iptables-nft and ufw.
|
|
if err = createChainIfNotExist(n.conn, chainInfo{nat, "POSTROUTING", nftables.ChainTypeNAT, nftables.ChainHookPostrouting, nftables.ChainPriorityNATSource, &polAccept}); err != nil {
|
|
return fmt.Errorf("create postrouting chain: %w", err)
|
|
}
|
|
// Adding the tailscale chain that contains our rules.
|
|
if err = createChainIfNotExist(n.conn, chainInfo{nat, chainNamePostrouting, chainTypeRegular, nil, nil, nil}); err != nil {
|
|
return fmt.Errorf("create postrouting chain: %w", err)
|
|
}
|
|
}
|
|
|
|
return n.conn.Flush()
|
|
}
|
|
|
|
// Names of the throwaway table and chain used to probe whether nftables is
// usable. createDummyPostroutingChains creates them and deletes them again;
// if both steps succeed nftables can be used, otherwise the system is
// assumed not to support it.
const (
	// tsDummyChainName is the probe chain created inside tsDummyTableName.
	tsDummyChainName = "ts-test-postrouting"
	// tsDummyTableName is the probe table.
	tsDummyTableName = "ts-test-nat"
)
|
|
|
|
// createDummyPostroutingChains creates dummy postrouting chains in netfilter
|
|
// via netfilter via nftables, as a last resort measure to detect that nftables
|
|
// can be used. It cleans up the dummy chains after creation.
|
|
func (n *nftablesRunner) createDummyPostroutingChains() (retErr error) {
|
|
polAccept := ptr.To(nftables.ChainPolicyAccept)
|
|
for _, table := range n.getNATTables() {
|
|
nat, err := createTableIfNotExist(n.conn, table.Proto, tsDummyTableName)
|
|
if err != nil {
|
|
return fmt.Errorf("create nat table: %w", err)
|
|
}
|
|
defer func(fm nftables.TableFamily) {
|
|
if err := deleteTableIfExists(n.conn, table.Proto, tsDummyTableName); err != nil && retErr == nil {
|
|
retErr = fmt.Errorf("delete %q table: %w", tsDummyTableName, err)
|
|
}
|
|
}(table.Proto)
|
|
|
|
table.Nat = nat
|
|
if err = createChainIfNotExist(n.conn, chainInfo{nat, tsDummyChainName, nftables.ChainTypeNAT, nftables.ChainHookPostrouting, nftables.ChainPriorityNATSource, polAccept}); err != nil {
|
|
return fmt.Errorf("create %q chain: %w", tsDummyChainName, err)
|
|
}
|
|
if err := deleteChainIfExists(n.conn, nat, tsDummyChainName); err != nil {
|
|
return fmt.Errorf("delete %q chain: %w", tsDummyChainName, err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// deleteChainIfExists deletes a chain if it exists.
|
|
func deleteChainIfExists(c *nftables.Conn, table *nftables.Table, name string) error {
|
|
chain, err := getChainFromTable(c, table, name)
|
|
if err != nil && !errors.Is(err, errorChainNotFound{table.Name, name}) {
|
|
return fmt.Errorf("get chain: %w", err)
|
|
} else if err != nil {
|
|
// If the chain doesn't exist, we don't need to delete it.
|
|
return nil
|
|
}
|
|
|
|
c.FlushChain(chain)
|
|
c.DelChain(chain)
|
|
|
|
if err := c.Flush(); err != nil {
|
|
return fmt.Errorf("flush and delete chain: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// DelChains removes the custom Tailscale chains from netfilter via nftables.
|
|
func (n *nftablesRunner) DelChains() error {
|
|
for _, table := range n.getTables() {
|
|
if err := deleteChainIfExists(n.conn, table.Filter, chainNameForward); err != nil {
|
|
return fmt.Errorf("delete chain: %w", err)
|
|
}
|
|
if err := deleteChainIfExists(n.conn, table.Filter, chainNameInput); err != nil {
|
|
return fmt.Errorf("delete chain: %w", err)
|
|
}
|
|
}
|
|
|
|
if err := deleteChainIfExists(n.conn, n.nft4.Nat, chainNamePostrouting); err != nil {
|
|
return fmt.Errorf("delete chain: %w", err)
|
|
}
|
|
|
|
if n.v6NATAvailable {
|
|
if err := deleteChainIfExists(n.conn, n.nft6.Nat, chainNamePostrouting); err != nil {
|
|
return fmt.Errorf("delete chain: %w", err)
|
|
}
|
|
}
|
|
|
|
if err := n.conn.Flush(); err != nil {
|
|
return fmt.Errorf("flush: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// createHookRule creates a rule to jump from a hooked chain to a regular chain.
|
|
func createHookRule(table *nftables.Table, fromChain *nftables.Chain, toChainName string) *nftables.Rule {
|
|
exprs := []expr.Any{
|
|
&expr.Counter{},
|
|
&expr.Verdict{
|
|
Kind: expr.VerdictJump,
|
|
Chain: toChainName,
|
|
},
|
|
}
|
|
|
|
rule := &nftables.Rule{
|
|
Table: table,
|
|
Chain: fromChain,
|
|
Exprs: exprs,
|
|
}
|
|
|
|
return rule
|
|
}
|
|
|
|
// addHookRule adds a rule to jump from a hooked chain to a regular chain at top of the hooked chain.
|
|
func addHookRule(conn *nftables.Conn, table *nftables.Table, fromChain *nftables.Chain, toChainName string) error {
|
|
rule := createHookRule(table, fromChain, toChainName)
|
|
_ = conn.InsertRule(rule)
|
|
|
|
if err := conn.Flush(); err != nil {
|
|
return fmt.Errorf("flush add rule: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// AddHooks is adding rules to conventional chains like "FORWARD", "INPUT" and "POSTROUTING"
|
|
// in tables and jump from those chains to tailscale chains.
|
|
func (n *nftablesRunner) AddHooks() error {
|
|
conn := n.conn
|
|
|
|
for _, table := range n.getTables() {
|
|
inputChain, err := getChainFromTable(conn, table.Filter, "INPUT")
|
|
if err != nil {
|
|
return fmt.Errorf("get INPUT chain: %w", err)
|
|
}
|
|
err = addHookRule(conn, table.Filter, inputChain, chainNameInput)
|
|
if err != nil {
|
|
return fmt.Errorf("Addhook: %w", err)
|
|
}
|
|
forwardChain, err := getChainFromTable(conn, table.Filter, "FORWARD")
|
|
if err != nil {
|
|
return fmt.Errorf("get FORWARD chain: %w", err)
|
|
}
|
|
err = addHookRule(conn, table.Filter, forwardChain, chainNameForward)
|
|
if err != nil {
|
|
return fmt.Errorf("Addhook: %w", err)
|
|
}
|
|
}
|
|
|
|
for _, table := range n.getNATTables() {
|
|
postroutingChain, err := getChainFromTable(conn, table.Nat, "POSTROUTING")
|
|
if err != nil {
|
|
return fmt.Errorf("get INPUT chain: %w", err)
|
|
}
|
|
err = addHookRule(conn, table.Nat, postroutingChain, chainNamePostrouting)
|
|
if err != nil {
|
|
return fmt.Errorf("Addhook: %w", err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// delHookRule deletes a rule that jumps from a hooked chain to a regular chain.
|
|
func delHookRule(conn *nftables.Conn, table *nftables.Table, fromChain *nftables.Chain, toChainName string) error {
|
|
rule := createHookRule(table, fromChain, toChainName)
|
|
existingRule, err := findRule(conn, rule)
|
|
if err != nil {
|
|
return fmt.Errorf("Failed to find hook rule: %w", err)
|
|
}
|
|
|
|
if existingRule == nil {
|
|
return nil
|
|
}
|
|
|
|
_ = conn.DelRule(existingRule)
|
|
|
|
if err := conn.Flush(); err != nil {
|
|
return fmt.Errorf("flush del hook rule: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// DelHooks is deleting the rules added to conventional chains to jump to tailscale chains.
|
|
func (n *nftablesRunner) DelHooks(logf logger.Logf) error {
|
|
conn := n.conn
|
|
|
|
for _, table := range n.getTables() {
|
|
inputChain, err := getChainFromTable(conn, table.Filter, "INPUT")
|
|
if err != nil {
|
|
return fmt.Errorf("get INPUT chain: %w", err)
|
|
}
|
|
err = delHookRule(conn, table.Filter, inputChain, chainNameInput)
|
|
if err != nil {
|
|
return fmt.Errorf("delhook: %w", err)
|
|
}
|
|
forwardChain, err := getChainFromTable(conn, table.Filter, "FORWARD")
|
|
if err != nil {
|
|
return fmt.Errorf("get FORWARD chain: %w", err)
|
|
}
|
|
err = delHookRule(conn, table.Filter, forwardChain, chainNameForward)
|
|
if err != nil {
|
|
return fmt.Errorf("delhook: %w", err)
|
|
}
|
|
}
|
|
|
|
for _, table := range n.getNATTables() {
|
|
postroutingChain, err := getChainFromTable(conn, table.Nat, "POSTROUTING")
|
|
if err != nil {
|
|
return fmt.Errorf("get INPUT chain: %w", err)
|
|
}
|
|
err = delHookRule(conn, table.Nat, postroutingChain, chainNamePostrouting)
|
|
if err != nil {
|
|
return fmt.Errorf("delhook: %w", err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// maskof returns the 4-byte netmask for pfx, encoded big endian, with the
// top pfx.Bits() bits set (for example /10 yields ff c0 00 00).
func maskof(pfx netip.Prefix) []byte {
	// Shifting all-ones right by the prefix length leaves the host bits set;
	// inverting that gives the network mask.
	hostBits := uint32(0xffff_ffff) >> pfx.Bits()

	var out [4]byte
	binary.BigEndian.PutUint32(out[:], ^hostBits)
	return out[:]
}
|
|
|
|
// createRangeRule creates a rule that matches packets with source IP from the give
|
|
// range (like CGNAT range or ChromeOSVM range) and the interface is not the tunname,
|
|
// and makes the given decision. Only IPv4 is supported.
|
|
func createRangeRule(
|
|
table *nftables.Table, chain *nftables.Chain,
|
|
tunname string, rng netip.Prefix, decision expr.VerdictKind,
|
|
) (*nftables.Rule, error) {
|
|
if rng.Addr().Is6() {
|
|
return nil, errors.New("IPv6 is not supported")
|
|
}
|
|
saddrExpr, err := newLoadSaddrExpr(nftables.TableFamilyIPv4, 1)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("newLoadSaddrExpr: %w", err)
|
|
}
|
|
netip := rng.Addr().AsSlice()
|
|
mask := maskof(rng)
|
|
rule := &nftables.Rule{
|
|
Table: table,
|
|
Chain: chain,
|
|
Exprs: []expr.Any{
|
|
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
|
&expr.Cmp{
|
|
Op: expr.CmpOpNeq,
|
|
Register: 1,
|
|
Data: []byte(tunname),
|
|
},
|
|
saddrExpr,
|
|
&expr.Bitwise{
|
|
SourceRegister: 1,
|
|
DestRegister: 1,
|
|
Len: 4,
|
|
Mask: mask,
|
|
Xor: []byte{0x00, 0x00, 0x00, 0x00},
|
|
},
|
|
&expr.Cmp{
|
|
Op: expr.CmpOpEq,
|
|
Register: 1,
|
|
Data: netip,
|
|
},
|
|
&expr.Counter{},
|
|
&expr.Verdict{
|
|
Kind: decision,
|
|
},
|
|
},
|
|
}
|
|
return rule, nil
|
|
|
|
}
|
|
|
|
// addReturnChromeOSVMRangeRule adds a rule to return if the source IP
|
|
// is in the ChromeOS VM range.
|
|
func addReturnChromeOSVMRangeRule(c *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error {
|
|
rule, err := createRangeRule(table, chain, tunname, tsaddr.ChromeOSVMRange(), expr.VerdictReturn)
|
|
if err != nil {
|
|
return fmt.Errorf("create rule: %w", err)
|
|
}
|
|
_ = c.AddRule(rule)
|
|
if err = c.Flush(); err != nil {
|
|
return fmt.Errorf("add rule: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// addDropCGNATRangeRule adds a rule to drop if the source IP is in the
|
|
// CGNAT range.
|
|
func addDropCGNATRangeRule(c *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error {
|
|
rule, err := createRangeRule(table, chain, tunname, tsaddr.CGNATRange(), expr.VerdictDrop)
|
|
if err != nil {
|
|
return fmt.Errorf("create rule: %w", err)
|
|
}
|
|
_ = c.AddRule(rule)
|
|
if err = c.Flush(); err != nil {
|
|
return fmt.Errorf("add rule: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// createSetSubnetRouteMarkRule creates a rule to set the subnet route
|
|
// mark if the packet is from the given interface.
|
|
func createSetSubnetRouteMarkRule(table *nftables.Table, chain *nftables.Chain, tunname string) (*nftables.Rule, error) {
|
|
hexTsFwmarkMaskNeg := getTailscaleFwmarkMaskNeg()
|
|
hexTSSubnetRouteMark := getTailscaleSubnetRouteMark()
|
|
|
|
rule := &nftables.Rule{
|
|
Table: table,
|
|
Chain: chain,
|
|
Exprs: []expr.Any{
|
|
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
|
&expr.Cmp{
|
|
Op: expr.CmpOpEq,
|
|
Register: 1,
|
|
Data: []byte(tunname),
|
|
},
|
|
&expr.Counter{},
|
|
&expr.Meta{Key: expr.MetaKeyMARK, Register: 1},
|
|
&expr.Bitwise{
|
|
SourceRegister: 1,
|
|
DestRegister: 1,
|
|
Len: 4,
|
|
Mask: hexTsFwmarkMaskNeg,
|
|
Xor: hexTSSubnetRouteMark,
|
|
},
|
|
&expr.Meta{
|
|
Key: expr.MetaKeyMARK,
|
|
SourceRegister: true,
|
|
Register: 1,
|
|
},
|
|
},
|
|
}
|
|
return rule, nil
|
|
}
|
|
|
|
// addSetSubnetRouteMarkRule adds a rule to set the subnet route mark
|
|
// if the packet is from the given interface.
|
|
func addSetSubnetRouteMarkRule(c *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error {
|
|
rule, err := createSetSubnetRouteMarkRule(table, chain, tunname)
|
|
if err != nil {
|
|
return fmt.Errorf("create rule: %w", err)
|
|
}
|
|
_ = c.AddRule(rule)
|
|
|
|
if err := c.Flush(); err != nil {
|
|
return fmt.Errorf("add rule: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// createDropOutgoingPacketFromCGNATRangeRuleWithTunname creates a rule to drop
|
|
// outgoing packets from the CGNAT range.
|
|
func createDropOutgoingPacketFromCGNATRangeRuleWithTunname(table *nftables.Table, chain *nftables.Chain, tunname string) (*nftables.Rule, error) {
|
|
_, ipNet, err := net.ParseCIDR(tsaddr.CGNATRange().String())
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parse cidr: %v", err)
|
|
}
|
|
mask, err := hex.DecodeString(ipNet.Mask.String())
|
|
if err != nil {
|
|
return nil, fmt.Errorf("decode mask: %v", err)
|
|
}
|
|
netip := ipNet.IP.Mask(ipNet.Mask).To4()
|
|
saddrExpr, err := newLoadSaddrExpr(nftables.TableFamilyIPv4, 1)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("newLoadSaddrExpr: %v", err)
|
|
}
|
|
rule := &nftables.Rule{
|
|
Table: table,
|
|
Chain: chain,
|
|
Exprs: []expr.Any{
|
|
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
|
&expr.Cmp{
|
|
Op: expr.CmpOpEq,
|
|
Register: 1,
|
|
Data: []byte(tunname),
|
|
},
|
|
saddrExpr,
|
|
&expr.Bitwise{
|
|
SourceRegister: 1,
|
|
DestRegister: 1,
|
|
Len: 4,
|
|
Mask: mask,
|
|
Xor: []byte{0x00, 0x00, 0x00, 0x00},
|
|
},
|
|
&expr.Cmp{
|
|
Op: expr.CmpOpEq,
|
|
Register: 1,
|
|
Data: netip,
|
|
},
|
|
&expr.Counter{},
|
|
&expr.Verdict{
|
|
Kind: expr.VerdictDrop,
|
|
},
|
|
},
|
|
}
|
|
return rule, nil
|
|
}
|
|
|
|
// addDropOutgoingPacketFromCGNATRangeRuleWithTunname adds a rule to drop
|
|
// outgoing packets from the CGNAT range.
|
|
func addDropOutgoingPacketFromCGNATRangeRuleWithTunname(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error {
|
|
rule, err := createDropOutgoingPacketFromCGNATRangeRuleWithTunname(table, chain, tunname)
|
|
if err != nil {
|
|
return fmt.Errorf("create rule: %w", err)
|
|
}
|
|
_ = conn.AddRule(rule)
|
|
|
|
if err := conn.Flush(); err != nil {
|
|
return fmt.Errorf("add rule: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// createAcceptOutgoingPacketRule creates a rule to accept outgoing packets
|
|
// from the given interface.
|
|
func createAcceptOutgoingPacketRule(table *nftables.Table, chain *nftables.Chain, tunname string) *nftables.Rule {
|
|
return &nftables.Rule{
|
|
Table: table,
|
|
Chain: chain,
|
|
Exprs: []expr.Any{
|
|
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
|
&expr.Cmp{
|
|
Op: expr.CmpOpEq,
|
|
Register: 1,
|
|
Data: []byte(tunname),
|
|
},
|
|
&expr.Counter{},
|
|
&expr.Verdict{
|
|
Kind: expr.VerdictAccept,
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
// addAcceptOutgoingPacketRule adds a rule to accept outgoing packets
|
|
// from the given interface.
|
|
func addAcceptOutgoingPacketRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error {
|
|
rule := createAcceptOutgoingPacketRule(table, chain, tunname)
|
|
_ = conn.AddRule(rule)
|
|
|
|
if err := conn.Flush(); err != nil {
|
|
return fmt.Errorf("flush add rule: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// createAcceptIncomingPacketRule creates a rule to accept incoming packets to
|
|
// the given interface.
|
|
func createAcceptIncomingPacketRule(table *nftables.Table, chain *nftables.Chain, tunname string) *nftables.Rule {
|
|
return &nftables.Rule{
|
|
Table: table,
|
|
Chain: chain,
|
|
Exprs: []expr.Any{
|
|
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
|
&expr.Cmp{
|
|
Op: expr.CmpOpEq,
|
|
Register: 1,
|
|
Data: []byte(tunname),
|
|
},
|
|
&expr.Counter{},
|
|
&expr.Verdict{
|
|
Kind: expr.VerdictAccept,
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
func addAcceptIncomingPacketRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error {
|
|
rule := createAcceptIncomingPacketRule(table, chain, tunname)
|
|
_ = conn.AddRule(rule)
|
|
|
|
if err := conn.Flush(); err != nil {
|
|
return fmt.Errorf("flush add rule: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// AddBase adds some basic processing rules.
|
|
func (n *nftablesRunner) AddBase(tunname string) error {
|
|
if err := n.addBase4(tunname); err != nil {
|
|
return fmt.Errorf("add base v4: %w", err)
|
|
}
|
|
if n.HasIPV6() {
|
|
if err := n.addBase6(tunname); err != nil {
|
|
return fmt.Errorf("add base v6: %w", err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// addBase4 adds some basic IPv4 processing rules.
|
|
func (n *nftablesRunner) addBase4(tunname string) error {
|
|
conn := n.conn
|
|
|
|
inputChain, err := getChainFromTable(conn, n.nft4.Filter, chainNameInput)
|
|
if err != nil {
|
|
return fmt.Errorf("get input chain v4: %v", err)
|
|
}
|
|
if err = addReturnChromeOSVMRangeRule(conn, n.nft4.Filter, inputChain, tunname); err != nil {
|
|
return fmt.Errorf("add return chromeos vm range rule v4: %w", err)
|
|
}
|
|
if err = addDropCGNATRangeRule(conn, n.nft4.Filter, inputChain, tunname); err != nil {
|
|
return fmt.Errorf("add drop cgnat range rule v4: %w", err)
|
|
}
|
|
if err = addAcceptIncomingPacketRule(conn, n.nft4.Filter, inputChain, tunname); err != nil {
|
|
return fmt.Errorf("add accept incoming packet rule v4: %w", err)
|
|
}
|
|
|
|
forwardChain, err := getChainFromTable(conn, n.nft4.Filter, chainNameForward)
|
|
if err != nil {
|
|
return fmt.Errorf("get forward chain v4: %v", err)
|
|
}
|
|
|
|
if err = addSetSubnetRouteMarkRule(conn, n.nft4.Filter, forwardChain, tunname); err != nil {
|
|
return fmt.Errorf("add set subnet route mark rule v4: %w", err)
|
|
}
|
|
|
|
if err = addMatchSubnetRouteMarkRule(conn, n.nft4.Filter, forwardChain, Accept); err != nil {
|
|
return fmt.Errorf("add match subnet route mark rule v4: %w", err)
|
|
}
|
|
|
|
if err = addDropOutgoingPacketFromCGNATRangeRuleWithTunname(conn, n.nft4.Filter, forwardChain, tunname); err != nil {
|
|
return fmt.Errorf("add drop outgoing packet from cgnat range rule v4: %w", err)
|
|
}
|
|
|
|
if err = addAcceptOutgoingPacketRule(conn, n.nft4.Filter, forwardChain, tunname); err != nil {
|
|
return fmt.Errorf("add accept outgoing packet rule v4: %w", err)
|
|
}
|
|
|
|
if err = conn.Flush(); err != nil {
|
|
return fmt.Errorf("flush base v4: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// addBase6 adds some basic IPv6 processing rules.
|
|
func (n *nftablesRunner) addBase6(tunname string) error {
|
|
conn := n.conn
|
|
|
|
inputChain, err := getChainFromTable(conn, n.nft6.Filter, chainNameInput)
|
|
if err != nil {
|
|
return fmt.Errorf("get input chain v4: %v", err)
|
|
}
|
|
if err = addAcceptIncomingPacketRule(conn, n.nft6.Filter, inputChain, tunname); err != nil {
|
|
return fmt.Errorf("add accept incoming packet rule v6: %w", err)
|
|
}
|
|
|
|
forwardChain, err := getChainFromTable(conn, n.nft6.Filter, chainNameForward)
|
|
if err != nil {
|
|
return fmt.Errorf("get forward chain v6: %w", err)
|
|
}
|
|
|
|
if err = addSetSubnetRouteMarkRule(conn, n.nft6.Filter, forwardChain, tunname); err != nil {
|
|
return fmt.Errorf("add set subnet route mark rule v6: %w", err)
|
|
}
|
|
|
|
if err = addMatchSubnetRouteMarkRule(conn, n.nft6.Filter, forwardChain, Accept); err != nil {
|
|
return fmt.Errorf("add match subnet route mark rule v6: %w", err)
|
|
}
|
|
|
|
if err = addAcceptOutgoingPacketRule(conn, n.nft6.Filter, forwardChain, tunname); err != nil {
|
|
return fmt.Errorf("add accept outgoing packet rule v6: %w", err)
|
|
}
|
|
|
|
if err = conn.Flush(); err != nil {
|
|
return fmt.Errorf("flush base v6: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// DelBase empties, but does not remove, custom Tailscale chains from
|
|
// netfilter via iptables.
|
|
func (n *nftablesRunner) DelBase() error {
|
|
conn := n.conn
|
|
|
|
for _, table := range n.getTables() {
|
|
inputChain, err := getChainFromTable(conn, table.Filter, chainNameInput)
|
|
if err != nil {
|
|
return fmt.Errorf("get input chain: %v", err)
|
|
}
|
|
conn.FlushChain(inputChain)
|
|
forwardChain, err := getChainFromTable(conn, table.Filter, chainNameForward)
|
|
if err != nil {
|
|
return fmt.Errorf("get forward chain: %v", err)
|
|
}
|
|
conn.FlushChain(forwardChain)
|
|
}
|
|
|
|
for _, table := range n.getNATTables() {
|
|
postrouteChain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting)
|
|
if err != nil {
|
|
return fmt.Errorf("get postrouting chain v4: %v", err)
|
|
}
|
|
conn.FlushChain(postrouteChain)
|
|
}
|
|
|
|
return conn.Flush()
|
|
}
|
|
|
|
// createMatchSubnetRouteMarkRule creates a rule that matches packets
|
|
// with the subnet route mark and takes the specified action.
|
|
func createMatchSubnetRouteMarkRule(table *nftables.Table, chain *nftables.Chain, action MatchDecision) (*nftables.Rule, error) {
|
|
hexTSFwmarkMask := getTailscaleFwmarkMask()
|
|
hexTSSubnetRouteMark := getTailscaleSubnetRouteMark()
|
|
|
|
var endAction expr.Any
|
|
endAction = &expr.Verdict{Kind: expr.VerdictAccept}
|
|
if action == Masq {
|
|
endAction = &expr.Masq{}
|
|
}
|
|
|
|
exprs := []expr.Any{
|
|
&expr.Meta{Key: expr.MetaKeyMARK, Register: 1},
|
|
&expr.Bitwise{
|
|
SourceRegister: 1,
|
|
DestRegister: 1,
|
|
Len: 4,
|
|
Mask: hexTSFwmarkMask,
|
|
Xor: []byte{0x00, 0x00, 0x00, 0x00},
|
|
},
|
|
&expr.Cmp{
|
|
Op: expr.CmpOpEq,
|
|
Register: 1,
|
|
Data: hexTSSubnetRouteMark,
|
|
},
|
|
&expr.Counter{},
|
|
endAction,
|
|
}
|
|
|
|
rule := &nftables.Rule{
|
|
Table: table,
|
|
Chain: chain,
|
|
Exprs: exprs,
|
|
}
|
|
return rule, nil
|
|
}
|
|
|
|
// addMatchSubnetRouteMarkRule adds a rule that matches packets with
|
|
// the subnet route mark and takes the specified action.
|
|
func addMatchSubnetRouteMarkRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, action MatchDecision) error {
|
|
rule, err := createMatchSubnetRouteMarkRule(table, chain, action)
|
|
if err != nil {
|
|
return fmt.Errorf("create match subnet route mark rule: %w", err)
|
|
}
|
|
_ = conn.AddRule(rule)
|
|
|
|
if err := conn.Flush(); err != nil {
|
|
return fmt.Errorf("flush add rule: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// AddSNATRule adds a netfilter rule to SNAT traffic destined for
|
|
// local subnets.
|
|
func (n *nftablesRunner) AddSNATRule() error {
|
|
conn := n.conn
|
|
|
|
for _, table := range n.getNATTables() {
|
|
chain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting)
|
|
if err != nil {
|
|
return fmt.Errorf("get postrouting chain v4: %w", err)
|
|
}
|
|
|
|
if err = addMatchSubnetRouteMarkRule(conn, table.Nat, chain, Masq); err != nil {
|
|
return fmt.Errorf("add match subnet route mark rule v4: %w", err)
|
|
}
|
|
}
|
|
|
|
if err := conn.Flush(); err != nil {
|
|
return fmt.Errorf("flush add SNAT rule: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// DelSNATRule removes the netfilter rule to SNAT traffic destined for
|
|
// local subnets. An error is returned if the rule does not exist.
|
|
func (n *nftablesRunner) DelSNATRule() error {
|
|
conn := n.conn
|
|
|
|
hexTSFwmarkMask := getTailscaleFwmarkMask()
|
|
hexTSSubnetRouteMark := getTailscaleSubnetRouteMark()
|
|
|
|
exprs := []expr.Any{
|
|
&expr.Meta{Key: expr.MetaKeyMARK, Register: 1},
|
|
&expr.Bitwise{
|
|
SourceRegister: 1,
|
|
DestRegister: 1,
|
|
Len: 4,
|
|
Mask: hexTSFwmarkMask,
|
|
},
|
|
&expr.Cmp{
|
|
Op: expr.CmpOpEq,
|
|
Register: 1,
|
|
Data: hexTSSubnetRouteMark,
|
|
},
|
|
&expr.Counter{},
|
|
&expr.Masq{},
|
|
}
|
|
|
|
for _, table := range n.getNATTables() {
|
|
chain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting)
|
|
if err != nil {
|
|
return fmt.Errorf("get postrouting chain v4: %w", err)
|
|
}
|
|
|
|
rule := &nftables.Rule{
|
|
Table: table.Nat,
|
|
Chain: chain,
|
|
Exprs: exprs,
|
|
}
|
|
|
|
SNATRule, err := findRule(conn, rule)
|
|
if err != nil {
|
|
return fmt.Errorf("find SNAT rule v4: %w", err)
|
|
}
|
|
|
|
if SNATRule != nil {
|
|
_ = conn.DelRule(SNATRule)
|
|
}
|
|
}
|
|
|
|
if err := conn.Flush(); err != nil {
|
|
return fmt.Errorf("flush del SNAT rule: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// cleanupChain removes a jump rule from hookChainName to tsChainName, and then
|
|
// the entire chain tsChainName. Errors are logged, but attempts to remove both
|
|
// the jump rule and chain continue even if one errors.
|
|
func cleanupChain(logf logger.Logf, conn *nftables.Conn, table *nftables.Table, hookChainName, tsChainName string) {
|
|
// remove the jump first, before removing the jump destination.
|
|
defaultChain, err := getChainFromTable(conn, table, hookChainName)
|
|
if err != nil && !errors.Is(err, errorChainNotFound{table.Name, hookChainName}) {
|
|
logf("cleanup: did not find default chain: %s", err)
|
|
}
|
|
if !errors.Is(err, errorChainNotFound{table.Name, hookChainName}) {
|
|
// delete hook in convention chain
|
|
_ = delHookRule(conn, table, defaultChain, tsChainName)
|
|
}
|
|
|
|
tsChain, err := getChainFromTable(conn, table, tsChainName)
|
|
if err != nil && !errors.Is(err, errorChainNotFound{table.Name, tsChainName}) {
|
|
logf("cleanup: did not find ts-chain: %s", err)
|
|
}
|
|
|
|
if tsChain != nil {
|
|
// flush and delete ts-chain
|
|
conn.FlushChain(tsChain)
|
|
conn.DelChain(tsChain)
|
|
err = conn.Flush()
|
|
logf("cleanup: delete and flush chain %s: %s", tsChainName, err)
|
|
}
|
|
}
|
|
|
|
// NfTablesCleanUp removes all Tailscale added nftables rules.
|
|
// Any errors that occur are logged to the provided logf.
|
|
func NfTablesCleanUp(logf logger.Logf) {
|
|
conn, err := nftables.New()
|
|
if err != nil {
|
|
logf("cleanup: nftables connection: %s", err)
|
|
}
|
|
|
|
tables, err := conn.ListTables() // both v4 and v6
|
|
if err != nil {
|
|
logf("cleanup: list tables: %s", err)
|
|
}
|
|
|
|
for _, table := range tables {
|
|
// These table names were used briefly in 1.48.0.
|
|
if table.Name == "ts-filter" || table.Name == "ts-nat" {
|
|
conn.DelTable(table)
|
|
if err := conn.Flush(); err != nil {
|
|
logf("cleanup: flush delete table %s: %s", table.Name, err)
|
|
}
|
|
}
|
|
|
|
if table.Name == "filter" {
|
|
cleanupChain(logf, conn, table, "INPUT", chainNameInput)
|
|
cleanupChain(logf, conn, table, "FORWARD", chainNameForward)
|
|
}
|
|
if table.Name == "nat" {
|
|
cleanupChain(logf, conn, table, "POSTROUTING", chainNamePostrouting)
|
|
}
|
|
}
|
|
}
|