mirror of
https://github.com/tailscale/tailscale.git
synced 2024-11-29 13:05:46 +00:00
978 lines
27 KiB
Go
978 lines
27 KiB
Go
|
// Copyright (c) Tailscale Inc & AUTHORS
|
||
|
// SPDX-License-Identifier: BSD-3-Clause
|
||
|
|
||
|
//go:build linux
|
||
|
|
||
|
package linuxfw
|
||
|
|
||
|
import (
|
||
|
"encoding/binary"
|
||
|
"encoding/hex"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"net"
|
||
|
"net/netip"
|
||
|
"reflect"
|
||
|
|
||
|
"github.com/google/nftables"
|
||
|
"github.com/google/nftables/expr"
|
||
|
"tailscale.com/net/tsaddr"
|
||
|
"tailscale.com/types/logger"
|
||
|
)
|
||
|
|
||
|
const (
|
||
|
chainNameForward = "ts-forward"
|
||
|
chainNameInput = "ts-input"
|
||
|
chainNamePostrouting = "ts-postrouting"
|
||
|
)
|
||
|
|
||
|
type chainInfo struct {
|
||
|
table *nftables.Table
|
||
|
name string
|
||
|
chainType nftables.ChainType
|
||
|
chainHook *nftables.ChainHook
|
||
|
chainPriority *nftables.ChainPriority
|
||
|
}
|
||
|
|
||
|
type nftable struct {
|
||
|
Proto nftables.TableFamily
|
||
|
Filter *nftables.Table
|
||
|
Nat *nftables.Table
|
||
|
}
|
||
|
|
||
|
type nftablesRunner struct {
|
||
|
conn *nftables.Conn
|
||
|
nft4 *nftable
|
||
|
nft6 *nftable
|
||
|
|
||
|
v6Available bool
|
||
|
v6NATAvailable bool
|
||
|
}
|
||
|
|
||
|
// createTableIfNotExist creates a nftables table via connection c if it does not exist within the given family.
|
||
|
func createTableIfNotExist(c *nftables.Conn, family nftables.TableFamily, name string) (*nftables.Table, error) {
|
||
|
tables, err := c.ListTables()
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("get tables: %w", err)
|
||
|
}
|
||
|
for _, table := range tables {
|
||
|
if table.Name == name && table.Family == family {
|
||
|
return table, nil
|
||
|
}
|
||
|
}
|
||
|
|
||
|
t := c.AddTable(&nftables.Table{
|
||
|
Family: family,
|
||
|
Name: name,
|
||
|
})
|
||
|
if err := c.Flush(); err != nil {
|
||
|
return nil, fmt.Errorf("add table: %w", err)
|
||
|
}
|
||
|
return t, nil
|
||
|
}
|
||
|
|
||
|
// errorChainNotFound is the error getChainFromTable returns when the table
// holds no chain with the requested name.
//
// It is constructed with positional composite literals of the form
// errorChainNotFound{table.Name, chainName}, so tableName must be declared
// first. With the fields declared the other way round, Error reported the table
// name as the chain name and vice versa.
type errorChainNotFound struct {
	tableName string
	chainName string
}

// Error implements the error interface.
func (e errorChainNotFound) Error() string {
	return fmt.Sprintf("chain %s not found in table %s", e.chainName, e.tableName)
}
|
||
|
|
||
|
// getChainFromTable returns the chain with the given name from the given table.
|
||
|
// Note that a chain name is unique within a table.
|
||
|
func getChainFromTable(c *nftables.Conn, table *nftables.Table, name string) (*nftables.Chain, error) {
|
||
|
chains, err := c.ListChainsOfTableFamily(table.Family)
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("list chains: %w", err)
|
||
|
}
|
||
|
|
||
|
for _, chain := range chains {
|
||
|
// Table family is already checked so table name is unique
|
||
|
if chain.Table.Name == table.Name && chain.Name == name {
|
||
|
return chain, nil
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return nil, errorChainNotFound{table.Name, name}
|
||
|
}
|
||
|
|
||
|
// getChainsFromTable returns all chains from the given table.
|
||
|
func getChainsFromTable(c *nftables.Conn, table *nftables.Table) ([]*nftables.Chain, error) {
|
||
|
chains, err := c.ListChainsOfTableFamily(table.Family)
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("list chains: %w", err)
|
||
|
}
|
||
|
|
||
|
var ret []*nftables.Chain
|
||
|
for _, chain := range chains {
|
||
|
// Table family is already checked so table name is unique
|
||
|
if chain.Table.Name == table.Name {
|
||
|
ret = append(ret, chain)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return ret, nil
|
||
|
}
|
||
|
|
||
|
// createChainIfNotExist creates a chain with the given name in the given table
|
||
|
// if it does not exist.
|
||
|
func createChainIfNotExist(c *nftables.Conn, cinfo chainInfo) error {
|
||
|
chain, err := getChainFromTable(c, cinfo.table, cinfo.name)
|
||
|
if err != nil && !errors.Is(err, errorChainNotFound{cinfo.table.Name, cinfo.name}) {
|
||
|
return fmt.Errorf("get chain: %w", err)
|
||
|
} else if err == nil {
|
||
|
// Chain already exists
|
||
|
if chain.Type != cinfo.chainType || chain.Hooknum != cinfo.chainHook || chain.Priority != cinfo.chainPriority {
|
||
|
return fmt.Errorf("chain %s already exists with different type/hook/priority", cinfo.name)
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
_ = c.AddChain(&nftables.Chain{
|
||
|
Name: cinfo.name,
|
||
|
Table: cinfo.table,
|
||
|
Type: cinfo.chainType,
|
||
|
Hooknum: cinfo.chainHook,
|
||
|
Priority: cinfo.chainPriority,
|
||
|
})
|
||
|
|
||
|
if err := c.Flush(); err != nil {
|
||
|
return fmt.Errorf("add chain: %w", err)
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// NewNfTablesRunner creates a new nftablesRunner without guaranteeing
|
||
|
// the existence of the tables and chains.
|
||
|
func NewNfTablesRunner(logf logger.Logf) (*nftablesRunner, error) {
|
||
|
conn, err := nftables.New()
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("nftables connection: %w", err)
|
||
|
}
|
||
|
nft4 := &nftable{Proto: nftables.TableFamilyIPv4}
|
||
|
|
||
|
v6err := checkIPv6(logf)
|
||
|
if v6err != nil {
|
||
|
logf("disabling tunneled IPv6 due to system IPv6 config: %w", v6err)
|
||
|
}
|
||
|
supportsV6 := v6err == nil
|
||
|
supportsV6NAT := supportsV6 && checkSupportsV6NAT()
|
||
|
|
||
|
var nft6 *nftable
|
||
|
if supportsV6 {
|
||
|
logf("v6nat availability: %v", supportsV6NAT)
|
||
|
nft6 = &nftable{Proto: nftables.TableFamilyIPv6}
|
||
|
}
|
||
|
|
||
|
// TODO(KevinLiang10): convert iptables rule to nftable rules if they exist in the iptables
|
||
|
|
||
|
return &nftablesRunner{
|
||
|
conn: conn,
|
||
|
nft4: nft4,
|
||
|
nft6: nft6,
|
||
|
v6Available: supportsV6,
|
||
|
v6NATAvailable: supportsV6NAT,
|
||
|
}, nil
|
||
|
}
|
||
|
|
||
|
// newLoadSaddrExpr creates a new nftables expression that loads the source
|
||
|
// address of the packet into the given register.
|
||
|
func newLoadSaddrExpr(proto nftables.TableFamily, destReg uint32) (expr.Any, error) {
|
||
|
switch proto {
|
||
|
case nftables.TableFamilyIPv4:
|
||
|
return &expr.Payload{
|
||
|
DestRegister: destReg,
|
||
|
Base: expr.PayloadBaseNetworkHeader,
|
||
|
Offset: 12,
|
||
|
Len: 4,
|
||
|
}, nil
|
||
|
case nftables.TableFamilyIPv6:
|
||
|
return &expr.Payload{
|
||
|
DestRegister: destReg,
|
||
|
Base: expr.PayloadBaseNetworkHeader,
|
||
|
Offset: 8,
|
||
|
Len: 16,
|
||
|
}, nil
|
||
|
default:
|
||
|
return nil, fmt.Errorf("table family %v is neither IPv4 nor IPv6", proto)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// HasIPV6 returns true if the system supports IPv6.
|
||
|
func (n *nftablesRunner) HasIPV6() bool {
|
||
|
return n.v6Available
|
||
|
}
|
||
|
|
||
|
// HasIPV6NAT returns true if the system supports IPv6 NAT.
|
||
|
func (n *nftablesRunner) HasIPV6NAT() bool {
|
||
|
return n.v6NATAvailable
|
||
|
}
|
||
|
|
||
|
// findRule iterates through the rules to find the rule with matching expressions.
|
||
|
func findRule(conn *nftables.Conn, rule *nftables.Rule) (*nftables.Rule, error) {
|
||
|
rules, err := conn.GetRules(rule.Table, rule.Chain)
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("get nftables rules: %w", err)
|
||
|
}
|
||
|
if len(rules) == 0 {
|
||
|
return nil, nil
|
||
|
}
|
||
|
|
||
|
ruleLoop:
|
||
|
for _, r := range rules {
|
||
|
if len(r.Exprs) != len(rule.Exprs) {
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
for i, e := range r.Exprs {
|
||
|
if !reflect.DeepEqual(e, rule.Exprs[i]) {
|
||
|
continue ruleLoop
|
||
|
}
|
||
|
}
|
||
|
return r, nil
|
||
|
}
|
||
|
|
||
|
return nil, nil
|
||
|
}
|
||
|
|
||
|
func createLoopbackRule(
|
||
|
proto nftables.TableFamily,
|
||
|
table *nftables.Table,
|
||
|
chain *nftables.Chain,
|
||
|
addr netip.Addr,
|
||
|
) (*nftables.Rule, error) {
|
||
|
saddrExpr, err := newLoadSaddrExpr(proto, 1)
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("newLoadSaddrExpr: %w", err)
|
||
|
}
|
||
|
loopBackRule := &nftables.Rule{
|
||
|
Table: table,
|
||
|
Chain: chain,
|
||
|
Exprs: []expr.Any{
|
||
|
&expr.Meta{
|
||
|
Key: expr.MetaKeyIIFNAME,
|
||
|
Register: 1,
|
||
|
},
|
||
|
&expr.Cmp{
|
||
|
Op: expr.CmpOpEq,
|
||
|
Register: 1,
|
||
|
Data: []byte("lo"),
|
||
|
},
|
||
|
saddrExpr,
|
||
|
&expr.Cmp{
|
||
|
Op: expr.CmpOpEq,
|
||
|
Register: 1,
|
||
|
Data: addr.AsSlice(),
|
||
|
},
|
||
|
&expr.Counter{},
|
||
|
&expr.Verdict{
|
||
|
Kind: expr.VerdictAccept,
|
||
|
},
|
||
|
},
|
||
|
}
|
||
|
return loopBackRule, nil
|
||
|
}
|
||
|
|
||
|
// insertLoopbackRule inserts the TS loop back rule into
|
||
|
// the given chain as the first rule if it does not exist.
|
||
|
func insertLoopbackRule(
|
||
|
conn *nftables.Conn, proto nftables.TableFamily,
|
||
|
table *nftables.Table, chain *nftables.Chain, addr netip.Addr) error {
|
||
|
|
||
|
loopBackRule, err := createLoopbackRule(proto, table, chain, addr)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("create loopback rule: %w", err)
|
||
|
}
|
||
|
|
||
|
// If TestDial is set, we are running in test mode and we should not
|
||
|
// find rule because header will mismatch.
|
||
|
if conn.TestDial == nil {
|
||
|
// Check if the rule already exists.
|
||
|
rule, err := findRule(conn, loopBackRule)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("find rule: %w", err)
|
||
|
}
|
||
|
if rule != nil {
|
||
|
// Rule already exists, no need to insert.
|
||
|
return nil
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// This inserts the rule to the top of the chain
|
||
|
_ = conn.InsertRule(loopBackRule)
|
||
|
|
||
|
if err = conn.Flush(); err != nil {
|
||
|
return fmt.Errorf("insert rule: %w", err)
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// getNFTByAddr returns the nftables with correct IP family
|
||
|
// that we will be using for the given address.
|
||
|
func (n *nftablesRunner) getNFTByAddr(addr netip.Addr) *nftable {
|
||
|
if addr.Is6() {
|
||
|
return n.nft6
|
||
|
}
|
||
|
return n.nft4
|
||
|
}
|
||
|
|
||
|
// AddLoopbackRule adds an nftables rule to permit loopback traffic to
|
||
|
// a local Tailscale IP. This rule is added only if it does not already exist.
|
||
|
func (n *nftablesRunner) AddLoopbackRule(addr netip.Addr) error {
|
||
|
nf := n.getNFTByAddr(addr)
|
||
|
|
||
|
inputChain, err := getChainFromTable(n.conn, nf.Filter, chainNameInput)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("get input chain: %w", err)
|
||
|
}
|
||
|
|
||
|
if err := insertLoopbackRule(n.conn, nf.Proto, nf.Filter, inputChain, addr); err != nil {
|
||
|
return fmt.Errorf("add loopback rule: %w", err)
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// DelLoopbackRule removes the nftables rule permitting loopback
|
||
|
// traffic to a Tailscale IP.
|
||
|
func (n *nftablesRunner) DelLoopbackRule(addr netip.Addr) error {
|
||
|
nf := n.getNFTByAddr(addr)
|
||
|
|
||
|
inputChain, err := getChainFromTable(n.conn, nf.Filter, chainNameInput)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("get input chain: %w", err)
|
||
|
}
|
||
|
|
||
|
loopBackRule, err := createLoopbackRule(nf.Proto, nf.Filter, inputChain, addr)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("create loopback rule: %w", err)
|
||
|
}
|
||
|
|
||
|
existingLoopBackRule, err := findRule(n.conn, loopBackRule)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("find loop back rule: %w", err)
|
||
|
}
|
||
|
if existingLoopBackRule == nil {
|
||
|
// Rule does not exist, no need to delete.
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
if err := n.conn.DelRule(existingLoopBackRule); err != nil {
|
||
|
return fmt.Errorf("delete rule: %w", err)
|
||
|
}
|
||
|
|
||
|
return n.conn.Flush()
|
||
|
}
|
||
|
|
||
|
// getTables gets the available nftable in nftables runner.
|
||
|
func (n *nftablesRunner) getTables() []*nftable {
|
||
|
if n.v6Available {
|
||
|
return []*nftable{n.nft4, n.nft6}
|
||
|
}
|
||
|
return []*nftable{n.nft4}
|
||
|
}
|
||
|
|
||
|
// getNATTables gets the available nftable in nftables runner.
|
||
|
// If the system does not support IPv6 NAT, only the IPv4 nftable
|
||
|
// will be returned.
|
||
|
func (n *nftablesRunner) getNATTables() []*nftable {
|
||
|
if n.v6NATAvailable {
|
||
|
return n.getTables()
|
||
|
}
|
||
|
return []*nftable{n.nft4}
|
||
|
}
|
||
|
|
||
|
// AddChains creates custom Tailscale chains in netfilter via nftables
|
||
|
// if the ts-chain doesn't already exist.
|
||
|
func (n *nftablesRunner) AddChains() error {
|
||
|
for _, table := range n.getTables() {
|
||
|
filter, err := createTableIfNotExist(n.conn, table.Proto, "ts-filter")
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("create table: %w", err)
|
||
|
}
|
||
|
table.Filter = filter
|
||
|
if err = createChainIfNotExist(n.conn, chainInfo{filter, chainNameForward, nftables.ChainTypeFilter, nftables.ChainHookForward, nftables.ChainPriorityRef(-1)}); err != nil {
|
||
|
return fmt.Errorf("create forward chain: %w", err)
|
||
|
}
|
||
|
if err = createChainIfNotExist(n.conn, chainInfo{filter, chainNameInput, nftables.ChainTypeFilter, nftables.ChainHookInput, nftables.ChainPriorityRef(-1)}); err != nil {
|
||
|
return fmt.Errorf("create input chain: %w", err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
for _, table := range n.getNATTables() {
|
||
|
nat, err := createTableIfNotExist(n.conn, table.Proto, "ts-nat")
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("create table: %w", err)
|
||
|
}
|
||
|
table.Nat = nat
|
||
|
if err = createChainIfNotExist(n.conn, chainInfo{nat, chainNamePostrouting, nftables.ChainTypeNAT, nftables.ChainHookPostrouting, nftables.ChainPriorityNATDest}); err != nil {
|
||
|
return fmt.Errorf("create postrouting chain: %w", err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return n.conn.Flush()
|
||
|
}
|
||
|
|
||
|
// deleteChainIfExists deletes a chain if it exists.
|
||
|
func deleteChainIfExists(c *nftables.Conn, table *nftables.Table, name string) error {
|
||
|
chain, err := getChainFromTable(c, table, name)
|
||
|
if err != nil && !errors.Is(err, errorChainNotFound{table.Name, name}) {
|
||
|
return fmt.Errorf("get chain: %w", err)
|
||
|
} else if err != nil {
|
||
|
// If the chain doesn't exist, we don't need to delete it.
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
c.FlushChain(chain)
|
||
|
c.DelChain(chain)
|
||
|
|
||
|
if err := c.Flush(); err != nil {
|
||
|
return fmt.Errorf("flush and delete chain: %w", err)
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// DelChains removes the custom Tailscale chains from netfilter via nftables.
|
||
|
func (n *nftablesRunner) DelChains() error {
|
||
|
for _, table := range n.getTables() {
|
||
|
if err := deleteChainIfExists(n.conn, table.Filter, chainNameForward); err != nil {
|
||
|
return fmt.Errorf("delete chain: %w", err)
|
||
|
}
|
||
|
if err := deleteChainIfExists(n.conn, table.Filter, chainNameInput); err != nil {
|
||
|
return fmt.Errorf("delete chain: %w", err)
|
||
|
}
|
||
|
n.conn.DelTable(table.Filter)
|
||
|
}
|
||
|
|
||
|
if err := deleteChainIfExists(n.conn, n.nft4.Nat, chainNamePostrouting); err != nil {
|
||
|
return fmt.Errorf("delete chain: %w", err)
|
||
|
}
|
||
|
n.conn.DelTable(n.nft4.Nat)
|
||
|
|
||
|
if n.v6NATAvailable {
|
||
|
if err := deleteChainIfExists(n.conn, n.nft6.Nat, chainNamePostrouting); err != nil {
|
||
|
return fmt.Errorf("delete chain: %w", err)
|
||
|
}
|
||
|
n.conn.DelTable(n.nft6.Nat)
|
||
|
}
|
||
|
|
||
|
if err := n.conn.Flush(); err != nil {
|
||
|
return fmt.Errorf("flush: %w", err)
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// AddHooks is defined to satisfy the interface. NfTables does not require
|
||
|
// AddHooks, since we don't have any default tables or chains in nftables.
|
||
|
func (n *nftablesRunner) AddHooks() error {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// DelHooks is defined to satisfy the interface. NfTables does not require
|
||
|
// DelHooks, since we don't have any default tables or chains in nftables.
|
||
|
func (n *nftablesRunner) DelHooks(logf logger.Logf) error {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// maskof returns the network mask of pfx as 4 big-endian bytes. The
// arithmetic assumes an IPv4 prefix of 0-32 bits.
func maskof(pfx netip.Prefix) []byte {
	hostBits := uint32(0xffff_ffff) >> pfx.Bits()
	var out [4]byte
	binary.BigEndian.PutUint32(out[:], ^hostBits)
	return out[:]
}
|
||
|
|
||
|
// createRangeRule creates a rule that matches packets with source IP from the give
|
||
|
// range (like CGNAT range or ChromeOSVM range) and the interface is not the tunname,
|
||
|
// and makes the given decision. Only IPv4 is supported.
|
||
|
func createRangeRule(
|
||
|
table *nftables.Table, chain *nftables.Chain,
|
||
|
tunname string, rng netip.Prefix, decision expr.VerdictKind,
|
||
|
) (*nftables.Rule, error) {
|
||
|
if rng.Addr().Is6() {
|
||
|
return nil, errors.New("IPv6 is not supported")
|
||
|
}
|
||
|
saddrExpr, err := newLoadSaddrExpr(nftables.TableFamilyIPv4, 1)
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("newLoadSaddrExpr: %w", err)
|
||
|
}
|
||
|
netip := rng.Addr().AsSlice()
|
||
|
mask := maskof(rng)
|
||
|
rule := &nftables.Rule{
|
||
|
Table: table,
|
||
|
Chain: chain,
|
||
|
Exprs: []expr.Any{
|
||
|
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||
|
&expr.Cmp{
|
||
|
Op: expr.CmpOpNeq,
|
||
|
Register: 1,
|
||
|
Data: []byte(tunname),
|
||
|
},
|
||
|
saddrExpr,
|
||
|
&expr.Bitwise{
|
||
|
SourceRegister: 1,
|
||
|
DestRegister: 1,
|
||
|
Len: 4,
|
||
|
Mask: mask,
|
||
|
Xor: []byte{0x00, 0x00, 0x00, 0x00},
|
||
|
},
|
||
|
&expr.Cmp{
|
||
|
Op: expr.CmpOpEq,
|
||
|
Register: 1,
|
||
|
Data: netip,
|
||
|
},
|
||
|
&expr.Counter{},
|
||
|
&expr.Verdict{
|
||
|
Kind: decision,
|
||
|
},
|
||
|
},
|
||
|
}
|
||
|
return rule, nil
|
||
|
|
||
|
}
|
||
|
|
||
|
// addReturnChromeOSVMRangeRule adds a rule to return if the source IP
|
||
|
// is in the ChromeOS VM range.
|
||
|
func addReturnChromeOSVMRangeRule(c *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error {
|
||
|
rule, err := createRangeRule(table, chain, tunname, tsaddr.ChromeOSVMRange(), expr.VerdictReturn)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("create rule: %w", err)
|
||
|
}
|
||
|
_ = c.AddRule(rule)
|
||
|
if err = c.Flush(); err != nil {
|
||
|
return fmt.Errorf("add rule: %w", err)
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// addDropCGNATRangeRule adds a rule to drop if the source IP is in the
|
||
|
// CGNAT range.
|
||
|
func addDropCGNATRangeRule(c *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error {
|
||
|
rule, err := createRangeRule(table, chain, tunname, tsaddr.CGNATRange(), expr.VerdictDrop)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("create rule: %w", err)
|
||
|
}
|
||
|
_ = c.AddRule(rule)
|
||
|
if err = c.Flush(); err != nil {
|
||
|
return fmt.Errorf("add rule: %w", err)
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// createSetSubnetRouteMarkRule creates a rule to set the subnet route
|
||
|
// mark if the packet is from the given interface.
|
||
|
func createSetSubnetRouteMarkRule(table *nftables.Table, chain *nftables.Chain, tunname string) (*nftables.Rule, error) {
|
||
|
hexTsFwmarkMaskNeg := getTailscaleFwmarkMaskNeg()
|
||
|
hexTSSubnetRouteMark := getTailscaleSubnetRouteMark()
|
||
|
|
||
|
rule := &nftables.Rule{
|
||
|
Table: table,
|
||
|
Chain: chain,
|
||
|
Exprs: []expr.Any{
|
||
|
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||
|
&expr.Cmp{
|
||
|
Op: expr.CmpOpEq,
|
||
|
Register: 1,
|
||
|
Data: []byte(tunname),
|
||
|
},
|
||
|
&expr.Counter{},
|
||
|
&expr.Meta{Key: expr.MetaKeyMARK, Register: 1},
|
||
|
&expr.Bitwise{
|
||
|
SourceRegister: 1,
|
||
|
DestRegister: 1,
|
||
|
Len: 4,
|
||
|
Mask: hexTsFwmarkMaskNeg,
|
||
|
Xor: hexTSSubnetRouteMark,
|
||
|
},
|
||
|
&expr.Meta{
|
||
|
Key: expr.MetaKeyMARK,
|
||
|
SourceRegister: true,
|
||
|
Register: 1,
|
||
|
},
|
||
|
},
|
||
|
}
|
||
|
return rule, nil
|
||
|
}
|
||
|
|
||
|
// addSetSubnetRouteMarkRule adds a rule to set the subnet route mark
|
||
|
// if the packet is from the given interface.
|
||
|
func addSetSubnetRouteMarkRule(c *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error {
|
||
|
rule, err := createSetSubnetRouteMarkRule(table, chain, tunname)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("create rule: %w", err)
|
||
|
}
|
||
|
_ = c.AddRule(rule)
|
||
|
|
||
|
if err := c.Flush(); err != nil {
|
||
|
return fmt.Errorf("add rule: %w", err)
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// createDropOutgoingPacketFromCGNATRangeRuleWithTunname creates a rule to drop
|
||
|
// outgoing packets from the CGNAT range.
|
||
|
func createDropOutgoingPacketFromCGNATRangeRuleWithTunname(table *nftables.Table, chain *nftables.Chain, tunname string) (*nftables.Rule, error) {
|
||
|
_, ipNet, err := net.ParseCIDR(tsaddr.CGNATRange().String())
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("parse cidr: %v", err)
|
||
|
}
|
||
|
mask, err := hex.DecodeString(ipNet.Mask.String())
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("decode mask: %v", err)
|
||
|
}
|
||
|
netip := ipNet.IP.Mask(ipNet.Mask).To4()
|
||
|
saddrExpr, err := newLoadSaddrExpr(nftables.TableFamilyIPv4, 1)
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("newLoadSaddrExpr: %v", err)
|
||
|
}
|
||
|
rule := &nftables.Rule{
|
||
|
Table: table,
|
||
|
Chain: chain,
|
||
|
Exprs: []expr.Any{
|
||
|
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||
|
&expr.Cmp{
|
||
|
Op: expr.CmpOpEq,
|
||
|
Register: 1,
|
||
|
Data: []byte(tunname),
|
||
|
},
|
||
|
saddrExpr,
|
||
|
&expr.Bitwise{
|
||
|
SourceRegister: 1,
|
||
|
DestRegister: 1,
|
||
|
Len: 4,
|
||
|
Mask: mask,
|
||
|
Xor: []byte{0x00, 0x00, 0x00, 0x00},
|
||
|
},
|
||
|
&expr.Cmp{
|
||
|
Op: expr.CmpOpEq,
|
||
|
Register: 1,
|
||
|
Data: netip,
|
||
|
},
|
||
|
&expr.Counter{},
|
||
|
&expr.Verdict{
|
||
|
Kind: expr.VerdictDrop,
|
||
|
},
|
||
|
},
|
||
|
}
|
||
|
return rule, nil
|
||
|
}
|
||
|
|
||
|
// addDropOutgoingPacketFromCGNATRangeRuleWithTunname adds a rule to drop
|
||
|
// outgoing packets from the CGNAT range.
|
||
|
func addDropOutgoingPacketFromCGNATRangeRuleWithTunname(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error {
|
||
|
rule, err := createDropOutgoingPacketFromCGNATRangeRuleWithTunname(table, chain, tunname)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("create rule: %w", err)
|
||
|
}
|
||
|
_ = conn.AddRule(rule)
|
||
|
|
||
|
if err := conn.Flush(); err != nil {
|
||
|
return fmt.Errorf("add rule: %w", err)
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// createAcceptOutgoingPacketRule creates a rule to accept outgoing packets
|
||
|
// from the given interface.
|
||
|
func createAcceptOutgoingPacketRule(table *nftables.Table, chain *nftables.Chain, tunname string) *nftables.Rule {
|
||
|
return &nftables.Rule{
|
||
|
Table: table,
|
||
|
Chain: chain,
|
||
|
Exprs: []expr.Any{
|
||
|
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||
|
&expr.Cmp{
|
||
|
Op: expr.CmpOpEq,
|
||
|
Register: 1,
|
||
|
Data: []byte(tunname),
|
||
|
},
|
||
|
&expr.Counter{},
|
||
|
&expr.Verdict{
|
||
|
Kind: expr.VerdictAccept,
|
||
|
},
|
||
|
},
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// addAcceptOutgoingPacketRule adds a rule to accept outgoing packets
|
||
|
// from the given interface.
|
||
|
func addAcceptOutgoingPacketRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error {
|
||
|
rule := createAcceptOutgoingPacketRule(table, chain, tunname)
|
||
|
_ = conn.AddRule(rule)
|
||
|
|
||
|
if err := conn.Flush(); err != nil {
|
||
|
return fmt.Errorf("flush add rule: %w", err)
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// AddBase adds some basic processing rules.
|
||
|
func (n *nftablesRunner) AddBase(tunname string) error {
|
||
|
if err := n.addBase4(tunname); err != nil {
|
||
|
return fmt.Errorf("add base v4: %w", err)
|
||
|
}
|
||
|
if n.HasIPV6() {
|
||
|
if err := n.addBase6(tunname); err != nil {
|
||
|
return fmt.Errorf("add base v6: %w", err)
|
||
|
}
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// addBase4 adds some basic IPv4 processing rules.
|
||
|
func (n *nftablesRunner) addBase4(tunname string) error {
|
||
|
conn := n.conn
|
||
|
|
||
|
inputChain, err := getChainFromTable(conn, n.nft4.Filter, chainNameInput)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("get input chain v4: %v", err)
|
||
|
}
|
||
|
if err = addReturnChromeOSVMRangeRule(conn, n.nft4.Filter, inputChain, tunname); err != nil {
|
||
|
return fmt.Errorf("add return chromeos vm range rule v4: %w", err)
|
||
|
}
|
||
|
if err = addDropCGNATRangeRule(conn, n.nft4.Filter, inputChain, tunname); err != nil {
|
||
|
return fmt.Errorf("add drop cgnat range rule v4: %w", err)
|
||
|
}
|
||
|
|
||
|
forwardChain, err := getChainFromTable(conn, n.nft4.Filter, chainNameForward)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("get forward chain v4: %v", err)
|
||
|
}
|
||
|
|
||
|
if err = addSetSubnetRouteMarkRule(conn, n.nft4.Filter, forwardChain, tunname); err != nil {
|
||
|
return fmt.Errorf("add set subnet route mark rule v4: %w", err)
|
||
|
}
|
||
|
|
||
|
if err = addMatchSubnetRouteMarkRule(conn, n.nft4.Filter, forwardChain, Accept); err != nil {
|
||
|
return fmt.Errorf("add match subnet route mark rule v4: %w", err)
|
||
|
}
|
||
|
|
||
|
if err = addDropOutgoingPacketFromCGNATRangeRuleWithTunname(conn, n.nft4.Filter, forwardChain, tunname); err != nil {
|
||
|
return fmt.Errorf("add drop outgoing packet from cgnat range rule v4: %w", err)
|
||
|
}
|
||
|
|
||
|
if err = addAcceptOutgoingPacketRule(conn, n.nft4.Filter, forwardChain, tunname); err != nil {
|
||
|
return fmt.Errorf("add accept outgoing packet rule v4: %w", err)
|
||
|
}
|
||
|
|
||
|
if err = conn.Flush(); err != nil {
|
||
|
return fmt.Errorf("flush base v4: %w", err)
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// addBase6 adds some basic IPv6 processing rules.
|
||
|
func (n *nftablesRunner) addBase6(tunname string) error {
|
||
|
conn := n.conn
|
||
|
|
||
|
forwardChain, err := getChainFromTable(conn, n.nft6.Filter, chainNameForward)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("get forward chain v6: %w", err)
|
||
|
}
|
||
|
|
||
|
if err = addSetSubnetRouteMarkRule(conn, n.nft6.Filter, forwardChain, tunname); err != nil {
|
||
|
return fmt.Errorf("add set subnet route mark rule v6: %w", err)
|
||
|
}
|
||
|
|
||
|
if err = addMatchSubnetRouteMarkRule(conn, n.nft6.Filter, forwardChain, Accept); err != nil {
|
||
|
return fmt.Errorf("add match subnet route mark rule v6: %w", err)
|
||
|
}
|
||
|
|
||
|
if err = addAcceptOutgoingPacketRule(conn, n.nft6.Filter, forwardChain, tunname); err != nil {
|
||
|
return fmt.Errorf("add accept outgoing packet rule v6: %w", err)
|
||
|
}
|
||
|
|
||
|
if err = conn.Flush(); err != nil {
|
||
|
return fmt.Errorf("flush base v6: %w", err)
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// DelBase empties, but does not remove, custom Tailscale chains from
|
||
|
// netfilter via iptables.
|
||
|
func (n *nftablesRunner) DelBase() error {
|
||
|
conn := n.conn
|
||
|
|
||
|
for _, table := range n.getTables() {
|
||
|
inputChain, err := getChainFromTable(conn, table.Filter, chainNameInput)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("get input chain: %v", err)
|
||
|
}
|
||
|
conn.FlushChain(inputChain)
|
||
|
forwardChain, err := getChainFromTable(conn, table.Filter, chainNameForward)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("get forward chain: %v", err)
|
||
|
}
|
||
|
conn.FlushChain(forwardChain)
|
||
|
}
|
||
|
|
||
|
for _, table := range n.getNATTables() {
|
||
|
postrouteChain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("get postrouting chain v4: %v", err)
|
||
|
}
|
||
|
conn.FlushChain(postrouteChain)
|
||
|
}
|
||
|
|
||
|
return conn.Flush()
|
||
|
}
|
||
|
|
||
|
// createMatchSubnetRouteMarkRule creates a rule that matches packets
|
||
|
// with the subnet route mark and takes the specified action.
|
||
|
func createMatchSubnetRouteMarkRule(table *nftables.Table, chain *nftables.Chain, action MatchDecision) (*nftables.Rule, error) {
|
||
|
hexTSFwmarkMask := getTailscaleFwmarkMask()
|
||
|
hexTSSubnetRouteMark := getTailscaleSubnetRouteMark()
|
||
|
|
||
|
var endAction expr.Any
|
||
|
endAction = &expr.Verdict{Kind: expr.VerdictAccept}
|
||
|
if action == Masq {
|
||
|
endAction = &expr.Masq{}
|
||
|
}
|
||
|
|
||
|
exprs := []expr.Any{
|
||
|
&expr.Meta{Key: expr.MetaKeyMARK, Register: 1},
|
||
|
&expr.Bitwise{
|
||
|
SourceRegister: 1,
|
||
|
DestRegister: 1,
|
||
|
Len: 4,
|
||
|
Mask: hexTSFwmarkMask,
|
||
|
Xor: []byte{0x00, 0x00, 0x00, 0x00},
|
||
|
},
|
||
|
&expr.Cmp{
|
||
|
Op: expr.CmpOpEq,
|
||
|
Register: 1,
|
||
|
Data: hexTSSubnetRouteMark,
|
||
|
},
|
||
|
&expr.Counter{},
|
||
|
endAction,
|
||
|
}
|
||
|
|
||
|
rule := &nftables.Rule{
|
||
|
Table: table,
|
||
|
Chain: chain,
|
||
|
Exprs: exprs,
|
||
|
}
|
||
|
return rule, nil
|
||
|
}
|
||
|
|
||
|
// addMatchSubnetRouteMarkRule adds a rule that matches packets with
|
||
|
// the subnet route mark and takes the specified action.
|
||
|
func addMatchSubnetRouteMarkRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, action MatchDecision) error {
|
||
|
rule, err := createMatchSubnetRouteMarkRule(table, chain, action)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("create match subnet route mark rule: %w", err)
|
||
|
}
|
||
|
_ = conn.AddRule(rule)
|
||
|
|
||
|
if err := conn.Flush(); err != nil {
|
||
|
return fmt.Errorf("flush add rule: %w", err)
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// AddSNATRule adds a netfilter rule to SNAT traffic destined for
|
||
|
// local subnets.
|
||
|
func (n *nftablesRunner) AddSNATRule() error {
|
||
|
conn := n.conn
|
||
|
|
||
|
for _, table := range n.getNATTables() {
|
||
|
chain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("get postrouting chain v4: %w", err)
|
||
|
}
|
||
|
|
||
|
if err = addMatchSubnetRouteMarkRule(conn, table.Nat, chain, Masq); err != nil {
|
||
|
return fmt.Errorf("add match subnet route mark rule v4: %w", err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if err := conn.Flush(); err != nil {
|
||
|
return fmt.Errorf("flush add SNAT rule: %w", err)
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// DelSNATRule removes the netfilter rule to SNAT traffic destined for
|
||
|
// local subnets. An error is returned if the rule does not exist.
|
||
|
func (n *nftablesRunner) DelSNATRule() error {
|
||
|
conn := n.conn
|
||
|
|
||
|
hexTSFwmarkMask := getTailscaleFwmarkMask()
|
||
|
hexTSSubnetRouteMark := getTailscaleSubnetRouteMark()
|
||
|
|
||
|
exprs := []expr.Any{
|
||
|
&expr.Meta{Key: expr.MetaKeyMARK, Register: 1},
|
||
|
&expr.Bitwise{
|
||
|
SourceRegister: 1,
|
||
|
DestRegister: 1,
|
||
|
Len: 4,
|
||
|
Mask: hexTSFwmarkMask,
|
||
|
},
|
||
|
&expr.Cmp{
|
||
|
Op: expr.CmpOpEq,
|
||
|
Register: 1,
|
||
|
Data: hexTSSubnetRouteMark,
|
||
|
},
|
||
|
&expr.Counter{},
|
||
|
&expr.Masq{},
|
||
|
}
|
||
|
|
||
|
for _, table := range n.getNATTables() {
|
||
|
chain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("get postrouting chain v4: %w", err)
|
||
|
}
|
||
|
|
||
|
rule := &nftables.Rule{
|
||
|
Table: table.Nat,
|
||
|
Chain: chain,
|
||
|
Exprs: exprs,
|
||
|
}
|
||
|
|
||
|
SNATRule, err := findRule(conn, rule)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("find SNAT rule v4: %w", err)
|
||
|
}
|
||
|
|
||
|
_ = conn.DelRule(SNATRule)
|
||
|
}
|
||
|
|
||
|
if err := conn.Flush(); err != nil {
|
||
|
return fmt.Errorf("flush del SNAT rule: %w", err)
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// NfTablesCleanUp removes all Tailscale added nftables rules.
|
||
|
// Any errors that occur are logged to the provided logf.
|
||
|
func NfTablesCleanUp(logf logger.Logf) {
|
||
|
conn, err := nftables.New()
|
||
|
if err != nil {
|
||
|
logf("ERROR: nftables connection: %w", err)
|
||
|
}
|
||
|
|
||
|
tables, err := conn.ListTables() // both v4 and v6
|
||
|
if err != nil {
|
||
|
logf("ERROR: list tables: %w", err)
|
||
|
}
|
||
|
|
||
|
for _, table := range tables {
|
||
|
if table.Name == "ts-filter" || table.Name == "ts-nat" {
|
||
|
conn.DelTable(table)
|
||
|
if err := conn.Flush(); err != nil {
|
||
|
logf("ERROR: flush table %s: %w", table.Name, err)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|