mirror of
https://github.com/tailscale/tailscale.git
synced 2025-04-21 06:01:42 +00:00
util/linuxfw: reorganize nftables rules to allow it to work with ufw
This commit tries to mimic the way iptables-nft work with the filewall rules. We follow the convention of using tables like filter, nat and the conventional chains, to make our nftables implementation work with ufw. Updates: #391 Signed-off-by: KevinLiang10 <kevinliang@tailscale.com>
This commit is contained in:
parent
d4586ca75f
commit
b040094b90
@ -13,6 +13,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/google/nftables"
|
"github.com/google/nftables"
|
||||||
"github.com/google/nftables/expr"
|
"github.com/google/nftables/expr"
|
||||||
@ -26,12 +27,16 @@ const (
|
|||||||
chainNamePostrouting = "ts-postrouting"
|
chainNamePostrouting = "ts-postrouting"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// chainTypeRegular is an nftables chain that does not apply to a hook.
|
||||||
|
const chainTypeRegular = ""
|
||||||
|
|
||||||
type chainInfo struct {
|
type chainInfo struct {
|
||||||
table *nftables.Table
|
table *nftables.Table
|
||||||
name string
|
name string
|
||||||
chainType nftables.ChainType
|
chainType nftables.ChainType
|
||||||
chainHook *nftables.ChainHook
|
chainHook *nftables.ChainHook
|
||||||
chainPriority *nftables.ChainPriority
|
chainPriority *nftables.ChainPriority
|
||||||
|
chainPolicy *nftables.ChainPolicy
|
||||||
}
|
}
|
||||||
|
|
||||||
type nftable struct {
|
type nftable struct {
|
||||||
@ -40,6 +45,21 @@ type nftable struct {
|
|||||||
Nat *nftables.Table
|
Nat *nftables.Table
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// nftablesRunner implements a netfilterRunner using the netlink based nftables
|
||||||
|
// library. As nftables allows for arbitrary tables and chains, there is a need
|
||||||
|
// to follow conventions in order to integrate well with a surrounding
|
||||||
|
// ecosystem. The rules installed by nftablesRunner have the following
|
||||||
|
// properties:
|
||||||
|
// - Install rules that intend to take precedence over rules installed by
|
||||||
|
// other software. Tailscale provides packet filtering for tailnet traffic
|
||||||
|
// inside the daemon based on the tailnet ACL rules.
|
||||||
|
// - As nftables "accept" is not final, rules from high priority tables (low
|
||||||
|
// numbers) will fall through to lower priority tables (high numbers). In
|
||||||
|
// order to effectively be 'final', we install "jump" rules into conventional
|
||||||
|
// tables and chains that will reach an accept verdict inside those tables.
|
||||||
|
// - The table and chain conventions followed here are those used by
|
||||||
|
// `iptables-nft` and `ufw`, so that those tools co-exist and do not
|
||||||
|
// negatively affect Tailscale function.
|
||||||
type nftablesRunner struct {
|
type nftablesRunner struct {
|
||||||
conn *nftables.Conn
|
conn *nftables.Conn
|
||||||
nft4 *nftable
|
nft4 *nftable
|
||||||
@ -116,6 +136,11 @@ func getChainsFromTable(c *nftables.Conn, table *nftables.Table) ([]*nftables.Ch
|
|||||||
return ret, nil
|
return ret, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isTSChain retruns true if the chain name starts with ts
|
||||||
|
func isTSChain(name string) bool {
|
||||||
|
return strings.HasPrefix(name, "ts-")
|
||||||
|
}
|
||||||
|
|
||||||
// createChainIfNotExist creates a chain with the given name in the given table
|
// createChainIfNotExist creates a chain with the given name in the given table
|
||||||
// if it does not exist.
|
// if it does not exist.
|
||||||
func createChainIfNotExist(c *nftables.Conn, cinfo chainInfo) error {
|
func createChainIfNotExist(c *nftables.Conn, cinfo chainInfo) error {
|
||||||
@ -123,8 +148,11 @@ func createChainIfNotExist(c *nftables.Conn, cinfo chainInfo) error {
|
|||||||
if err != nil && !errors.Is(err, errorChainNotFound{cinfo.table.Name, cinfo.name}) {
|
if err != nil && !errors.Is(err, errorChainNotFound{cinfo.table.Name, cinfo.name}) {
|
||||||
return fmt.Errorf("get chain: %w", err)
|
return fmt.Errorf("get chain: %w", err)
|
||||||
} else if err == nil {
|
} else if err == nil {
|
||||||
// Chain already exists
|
// The chain already exists. If it is a TS chain, check the
|
||||||
if chain.Type != cinfo.chainType || chain.Hooknum != cinfo.chainHook || chain.Priority != cinfo.chainPriority {
|
// type/hook/priority, but for "conventional chains" assume they're what
|
||||||
|
// we expect (in case iptables-nft/ufw make minor behavior changes in
|
||||||
|
// the future).
|
||||||
|
if isTSChain(chain.Name) && (chain.Type != cinfo.chainType || chain.Hooknum != cinfo.chainHook || chain.Priority != cinfo.chainPriority) {
|
||||||
return fmt.Errorf("chain %s already exists with different type/hook/priority", cinfo.name)
|
return fmt.Errorf("chain %s already exists with different type/hook/priority", cinfo.name)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@ -136,6 +164,7 @@ func createChainIfNotExist(c *nftables.Conn, cinfo chainInfo) error {
|
|||||||
Type: cinfo.chainType,
|
Type: cinfo.chainType,
|
||||||
Hooknum: cinfo.chainHook,
|
Hooknum: cinfo.chainHook,
|
||||||
Priority: cinfo.chainPriority,
|
Priority: cinfo.chainPriority,
|
||||||
|
Policy: cinfo.chainPolicy,
|
||||||
})
|
})
|
||||||
|
|
||||||
if err := c.Flush(); err != nil {
|
if err := c.Flush(); err != nil {
|
||||||
@ -228,6 +257,10 @@ ruleLoop:
|
|||||||
}
|
}
|
||||||
|
|
||||||
for i, e := range r.Exprs {
|
for i, e := range r.Exprs {
|
||||||
|
// Skip counter expressions, as they will not match.
|
||||||
|
if _, ok := e.(*expr.Counter); ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if !reflect.DeepEqual(e, rule.Exprs[i]) {
|
if !reflect.DeepEqual(e, rule.Exprs[i]) {
|
||||||
continue ruleLoop
|
continue ruleLoop
|
||||||
}
|
}
|
||||||
@ -388,27 +421,49 @@ func (n *nftablesRunner) getNATTables() []*nftable {
|
|||||||
// AddChains creates custom Tailscale chains in netfilter via nftables
|
// AddChains creates custom Tailscale chains in netfilter via nftables
|
||||||
// if the ts-chain doesn't already exist.
|
// if the ts-chain doesn't already exist.
|
||||||
func (n *nftablesRunner) AddChains() error {
|
func (n *nftablesRunner) AddChains() error {
|
||||||
|
polAccept := nftables.ChainPolicyAccept
|
||||||
for _, table := range n.getTables() {
|
for _, table := range n.getTables() {
|
||||||
filter, err := createTableIfNotExist(n.conn, table.Proto, "ts-filter")
|
// Create the filter table if it doesn't exist, this table name is the same
|
||||||
|
// as the name used by iptables-nft and ufw. We install rules into the
|
||||||
|
// same conventional table so that `accept` verdicts from our jump
|
||||||
|
// chains are conclusive.
|
||||||
|
filter, err := createTableIfNotExist(n.conn, table.Proto, "filter")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create table: %w", err)
|
return fmt.Errorf("create table: %w", err)
|
||||||
}
|
}
|
||||||
table.Filter = filter
|
table.Filter = filter
|
||||||
if err = createChainIfNotExist(n.conn, chainInfo{filter, chainNameForward, nftables.ChainTypeFilter, nftables.ChainHookForward, nftables.ChainPriorityRef(-1)}); err != nil {
|
// Adding the "conventional chains" that are used by iptables-nft and ufw.
|
||||||
|
if err = createChainIfNotExist(n.conn, chainInfo{filter, "FORWARD", nftables.ChainTypeFilter, nftables.ChainHookForward, nftables.ChainPriorityFilter, &polAccept}); err != nil {
|
||||||
return fmt.Errorf("create forward chain: %w", err)
|
return fmt.Errorf("create forward chain: %w", err)
|
||||||
}
|
}
|
||||||
if err = createChainIfNotExist(n.conn, chainInfo{filter, chainNameInput, nftables.ChainTypeFilter, nftables.ChainHookInput, nftables.ChainPriorityRef(-1)}); err != nil {
|
if err = createChainIfNotExist(n.conn, chainInfo{filter, "INPUT", nftables.ChainTypeFilter, nftables.ChainHookInput, nftables.ChainPriorityFilter, &polAccept}); err != nil {
|
||||||
|
return fmt.Errorf("create input chain: %w", err)
|
||||||
|
}
|
||||||
|
// Adding the tailscale chains that contain our rules.
|
||||||
|
if err = createChainIfNotExist(n.conn, chainInfo{filter, chainNameForward, chainTypeRegular, nil, nil, nil}); err != nil {
|
||||||
|
return fmt.Errorf("create forward chain: %w", err)
|
||||||
|
}
|
||||||
|
if err = createChainIfNotExist(n.conn, chainInfo{filter, chainNameInput, chainTypeRegular, nil, nil, nil}); err != nil {
|
||||||
return fmt.Errorf("create input chain: %w", err)
|
return fmt.Errorf("create input chain: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, table := range n.getNATTables() {
|
for _, table := range n.getNATTables() {
|
||||||
nat, err := createTableIfNotExist(n.conn, table.Proto, "ts-nat")
|
// Create the nat table if it doesn't exist, this table name is the same
|
||||||
|
// as the name used by iptables-nft and ufw. We install rules into the
|
||||||
|
// same conventional table so that `accept` verdicts from our jump
|
||||||
|
// chains are conclusive.
|
||||||
|
nat, err := createTableIfNotExist(n.conn, table.Proto, "nat")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create table: %w", err)
|
return fmt.Errorf("create table: %w", err)
|
||||||
}
|
}
|
||||||
table.Nat = nat
|
table.Nat = nat
|
||||||
if err = createChainIfNotExist(n.conn, chainInfo{nat, chainNamePostrouting, nftables.ChainTypeNAT, nftables.ChainHookPostrouting, nftables.ChainPriorityNATDest}); err != nil {
|
// Adding the "conventional chains" that are used by iptables-nft and ufw.
|
||||||
|
if err = createChainIfNotExist(n.conn, chainInfo{nat, "POSTROUTING", nftables.ChainTypeNAT, nftables.ChainHookPostrouting, nftables.ChainPriorityNATSource, &polAccept}); err != nil {
|
||||||
|
return fmt.Errorf("create postrouting chain: %w", err)
|
||||||
|
}
|
||||||
|
// Adding the tailscale chain that contains our rules.
|
||||||
|
if err = createChainIfNotExist(n.conn, chainInfo{nat, chainNamePostrouting, chainTypeRegular, nil, nil, nil}); err != nil {
|
||||||
return fmt.Errorf("create postrouting chain: %w", err)
|
return fmt.Errorf("create postrouting chain: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -445,19 +500,16 @@ func (n *nftablesRunner) DelChains() error {
|
|||||||
if err := deleteChainIfExists(n.conn, table.Filter, chainNameInput); err != nil {
|
if err := deleteChainIfExists(n.conn, table.Filter, chainNameInput); err != nil {
|
||||||
return fmt.Errorf("delete chain: %w", err)
|
return fmt.Errorf("delete chain: %w", err)
|
||||||
}
|
}
|
||||||
n.conn.DelTable(table.Filter)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := deleteChainIfExists(n.conn, n.nft4.Nat, chainNamePostrouting); err != nil {
|
if err := deleteChainIfExists(n.conn, n.nft4.Nat, chainNamePostrouting); err != nil {
|
||||||
return fmt.Errorf("delete chain: %w", err)
|
return fmt.Errorf("delete chain: %w", err)
|
||||||
}
|
}
|
||||||
n.conn.DelTable(n.nft4.Nat)
|
|
||||||
|
|
||||||
if n.v6NATAvailable {
|
if n.v6NATAvailable {
|
||||||
if err := deleteChainIfExists(n.conn, n.nft6.Nat, chainNamePostrouting); err != nil {
|
if err := deleteChainIfExists(n.conn, n.nft6.Nat, chainNamePostrouting); err != nil {
|
||||||
return fmt.Errorf("delete chain: %w", err)
|
return fmt.Errorf("delete chain: %w", err)
|
||||||
}
|
}
|
||||||
n.conn.DelTable(n.nft6.Nat)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := n.conn.Flush(); err != nil {
|
if err := n.conn.Flush(); err != nil {
|
||||||
@ -467,15 +519,128 @@ func (n *nftablesRunner) DelChains() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddHooks is defined to satisfy the interface. NfTables does not require
|
// createHookRule creates a rule to jump from a hooked chain to a regular chain.
|
||||||
// AddHooks, since we don't have any default tables or chains in nftables.
|
func createHookRule(table *nftables.Table, fromChain *nftables.Chain, toChainName string) *nftables.Rule {
|
||||||
func (n *nftablesRunner) AddHooks() error {
|
exprs := []expr.Any{
|
||||||
|
&expr.Counter{},
|
||||||
|
&expr.Verdict{
|
||||||
|
Kind: expr.VerdictJump,
|
||||||
|
Chain: toChainName,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
rule := &nftables.Rule{
|
||||||
|
Table: table,
|
||||||
|
Chain: fromChain,
|
||||||
|
Exprs: exprs,
|
||||||
|
}
|
||||||
|
|
||||||
|
return rule
|
||||||
|
}
|
||||||
|
|
||||||
|
// addHookRule adds a rule to jump from a hooked chain to a regular chain at top of the hooked chain.
|
||||||
|
func addHookRule(conn *nftables.Conn, table *nftables.Table, fromChain *nftables.Chain, toChainName string) error {
|
||||||
|
rule := createHookRule(table, fromChain, toChainName)
|
||||||
|
_ = conn.InsertRule(rule)
|
||||||
|
|
||||||
|
if err := conn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf("flush add rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DelHooks is defined to satisfy the interface. NfTables does not require
|
// AddHooks is adding rules to conventional chains like "FORWARD", "INPUT" and "POSTROUTING"
|
||||||
// DelHooks, since we don't have any default tables or chains in nftables.
|
// in tables and jump from those chains to tailscale chains.
|
||||||
|
func (n *nftablesRunner) AddHooks() error {
|
||||||
|
conn := n.conn
|
||||||
|
|
||||||
|
for _, table := range n.getTables() {
|
||||||
|
inputChain, err := getChainFromTable(conn, table.Filter, "INPUT")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get INPUT chain: %w", err)
|
||||||
|
}
|
||||||
|
err = addHookRule(conn, table.Filter, inputChain, chainNameInput)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Addhook: %w", err)
|
||||||
|
}
|
||||||
|
forwardChain, err := getChainFromTable(conn, table.Filter, "FORWARD")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get FORWARD chain: %w", err)
|
||||||
|
}
|
||||||
|
err = addHookRule(conn, table.Filter, forwardChain, chainNameForward)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Addhook: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, table := range n.getNATTables() {
|
||||||
|
postroutingChain, err := getChainFromTable(conn, table.Nat, "POSTROUTING")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get INPUT chain: %w", err)
|
||||||
|
}
|
||||||
|
err = addHookRule(conn, table.Nat, postroutingChain, chainNamePostrouting)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Addhook: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// delHookRule deletes a rule that jumps from a hooked chain to a regular chain.
|
||||||
|
func delHookRule(conn *nftables.Conn, table *nftables.Table, fromChain *nftables.Chain, toChainName string) error {
|
||||||
|
rule := createHookRule(table, fromChain, toChainName)
|
||||||
|
existingRule, err := findRule(conn, rule)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Failed to find hook rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if existingRule == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = conn.DelRule(existingRule)
|
||||||
|
|
||||||
|
if err := conn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf("flush del hook rule: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DelHooks is deleting the rules added to conventional chains to jump to tailscale chains.
|
||||||
func (n *nftablesRunner) DelHooks(logf logger.Logf) error {
|
func (n *nftablesRunner) DelHooks(logf logger.Logf) error {
|
||||||
|
conn := n.conn
|
||||||
|
|
||||||
|
for _, table := range n.getTables() {
|
||||||
|
inputChain, err := getChainFromTable(conn, table.Filter, "INPUT")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get INPUT chain: %w", err)
|
||||||
|
}
|
||||||
|
err = delHookRule(conn, table.Filter, inputChain, chainNameInput)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("delhook: %w", err)
|
||||||
|
}
|
||||||
|
forwardChain, err := getChainFromTable(conn, table.Filter, "FORWARD")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get FORWARD chain: %w", err)
|
||||||
|
}
|
||||||
|
err = delHookRule(conn, table.Filter, forwardChain, chainNameForward)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("delhook: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, table := range n.getNATTables() {
|
||||||
|
postroutingChain, err := getChainFromTable(conn, table.Nat, "POSTROUTING")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get INPUT chain: %w", err)
|
||||||
|
}
|
||||||
|
err = delHookRule(conn, table.Nat, postroutingChain, chainNamePostrouting)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("delhook: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -953,25 +1118,62 @@ func (n *nftablesRunner) DelSNATRule() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// cleanupChain removes a jump rule from hookChainName to tsChainName, and then
|
||||||
|
// the entire chain tsChainName. Errors are logged, but attempts to remove both
|
||||||
|
// the jump rule and chain continue even if one errors.
|
||||||
|
func cleanupChain(logf logger.Logf, conn *nftables.Conn, table *nftables.Table, hookChainName, tsChainName string) {
|
||||||
|
// remove the jump first, before removing the jump destination.
|
||||||
|
defaultChain, err := getChainFromTable(conn, table, hookChainName)
|
||||||
|
if err != nil && !errors.Is(err, errorChainNotFound{table.Name, hookChainName}) {
|
||||||
|
logf("cleanup: did not find default chain: %s", err)
|
||||||
|
}
|
||||||
|
if !errors.Is(err, errorChainNotFound{table.Name, hookChainName}) {
|
||||||
|
// delete hook in convention chain
|
||||||
|
_ = delHookRule(conn, table, defaultChain, tsChainName)
|
||||||
|
}
|
||||||
|
|
||||||
|
tsChain, err := getChainFromTable(conn, table, tsChainName)
|
||||||
|
if err != nil && !errors.Is(err, errorChainNotFound{table.Name, tsChainName}) {
|
||||||
|
logf("cleanup: did not find ts-chain: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tsChain != nil {
|
||||||
|
// flush and delete ts-chain
|
||||||
|
conn.FlushChain(tsChain)
|
||||||
|
conn.DelChain(tsChain)
|
||||||
|
err = conn.Flush()
|
||||||
|
logf("cleanup: delete and flush chain %s: %s", tsChainName, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// NfTablesCleanUp removes all Tailscale added nftables rules.
|
// NfTablesCleanUp removes all Tailscale added nftables rules.
|
||||||
// Any errors that occur are logged to the provided logf.
|
// Any errors that occur are logged to the provided logf.
|
||||||
func NfTablesCleanUp(logf logger.Logf) {
|
func NfTablesCleanUp(logf logger.Logf) {
|
||||||
conn, err := nftables.New()
|
conn, err := nftables.New()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logf("ERROR: nftables connection: %w", err)
|
logf("cleanup: nftables connection: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
tables, err := conn.ListTables() // both v4 and v6
|
tables, err := conn.ListTables() // both v4 and v6
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logf("ERROR: list tables: %w", err)
|
logf("cleanup: list tables: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, table := range tables {
|
for _, table := range tables {
|
||||||
|
// These table names were used briefly in 1.48.0.
|
||||||
if table.Name == "ts-filter" || table.Name == "ts-nat" {
|
if table.Name == "ts-filter" || table.Name == "ts-nat" {
|
||||||
conn.DelTable(table)
|
conn.DelTable(table)
|
||||||
if err := conn.Flush(); err != nil {
|
if err := conn.Flush(); err != nil {
|
||||||
logf("ERROR: flush table %s: %w", table.Name, err)
|
logf("cleanup: flush delete table %s: %s", table.Name, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if table.Name == "filter" {
|
||||||
|
cleanupChain(logf, conn, table, "INPUT", chainNameInput)
|
||||||
|
cleanupChain(logf, conn, table, "FORWARD", chainNameForward)
|
||||||
|
}
|
||||||
|
if table.Name == "nat" {
|
||||||
|
cleanupChain(logf, conn, table, "POSTROUTING", chainNamePostrouting)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -101,6 +101,48 @@ func newTestConn(t *testing.T, want [][]byte) *nftables.Conn {
|
|||||||
return conn
|
return conn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestInsertHookRule(t *testing.T) {
|
||||||
|
proto := nftables.TableFamilyIPv4
|
||||||
|
want := [][]byte{
|
||||||
|
// batch begin
|
||||||
|
[]byte("\x00\x00\x00\x0a"),
|
||||||
|
// nft add table ip ts-filter-test
|
||||||
|
[]byte("\x02\x00\x00\x00\x13\x00\x01\x00\x74\x73\x2d\x66\x69\x6c\x74\x65\x72\x2d\x74\x65\x73\x74\x00\x00\x08\x00\x02\x00\x00\x00\x00\x00"),
|
||||||
|
// nft add chain ip ts-filter-test ts-input-test { type filter hook input priority 0 \; }
|
||||||
|
[]byte("\x02\x00\x00\x00\x13\x00\x01\x00\x74\x73\x2d\x66\x69\x6c\x74\x65\x72\x2d\x74\x65\x73\x74\x00\x00\x12\x00\x03\x00\x74\x73\x2d\x69\x6e\x70\x75\x74\x2d\x74\x65\x73\x74\x00\x00\x00\x14\x00\x04\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x00\x0b\x00\x07\x00\x66\x69\x6c\x74\x65\x72\x00\x00"),
|
||||||
|
// nft add chain ip ts-filter-test ts-jumpto
|
||||||
|
[]byte("\x02\x00\x00\x00\x13\x00\x01\x00\x74\x73\x2d\x66\x69\x6c\x74\x65\x72\x2d\x74\x65\x73\x74\x00\x00\x0e\x00\x03\x00\x74\x73\x2d\x6a\x75\x6d\x70\x74\x6f\x00\x00\x00"),
|
||||||
|
// nft add rule ip ts-filter-test ts-input-test counter jump ts-jumptp
|
||||||
|
[]byte("\x02\x00\x00\x00\x13\x00\x01\x00\x74\x73\x2d\x66\x69\x6c\x74\x65\x72\x2d\x74\x65\x73\x74\x00\x00\x12\x00\x02\x00\x74\x73\x2d\x69\x6e\x70\x75\x74\x2d\x74\x65\x73\x74\x00\x00\x00\x70\x00\x04\x80\x2c\x00\x01\x80\x0c\x00\x01\x00\x63\x6f\x75\x6e\x74\x65\x72\x00\x1c\x00\x02\x80\x0c\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0c\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x40\x00\x01\x80\x0e\x00\x01\x00\x69\x6d\x6d\x65\x64\x69\x61\x74\x65\x00\x00\x00\x2c\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x00\x20\x00\x02\x80\x1c\x00\x02\x80\x08\x00\x01\x00\xff\xff\xff\xfd\x0e\x00\x02\x00\x74\x73\x2d\x6a\x75\x6d\x70\x74\x6f\x00\x00\x00"),
|
||||||
|
// batch end
|
||||||
|
[]byte("\x00\x00\x00\x0a"),
|
||||||
|
}
|
||||||
|
testConn := newTestConn(t, want)
|
||||||
|
table := testConn.AddTable(&nftables.Table{
|
||||||
|
Family: proto,
|
||||||
|
Name: "ts-filter-test",
|
||||||
|
})
|
||||||
|
|
||||||
|
fromchain := testConn.AddChain(&nftables.Chain{
|
||||||
|
Name: "ts-input-test",
|
||||||
|
Table: table,
|
||||||
|
Type: nftables.ChainTypeFilter,
|
||||||
|
Hooknum: nftables.ChainHookInput,
|
||||||
|
Priority: nftables.ChainPriorityFilter,
|
||||||
|
})
|
||||||
|
|
||||||
|
tochain := testConn.AddChain(&nftables.Chain{
|
||||||
|
Name: "ts-jumpto",
|
||||||
|
Table: table,
|
||||||
|
})
|
||||||
|
|
||||||
|
err := addHookRule(testConn, table, fromchain, tochain.Name)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
func TestInsertLoopbackRule(t *testing.T) {
|
func TestInsertLoopbackRule(t *testing.T) {
|
||||||
proto := nftables.TableFamilyIPv4
|
proto := nftables.TableFamilyIPv4
|
||||||
want := [][]byte{
|
want := [][]byte{
|
||||||
@ -461,8 +503,8 @@ func TestAddAndDelNetfilterChains(t *testing.T) {
|
|||||||
t.Fatalf("list chains failed: %v", err)
|
t.Fatalf("list chains failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(chainsV4) != 3 {
|
if len(chainsV4) != 6 {
|
||||||
t.Fatalf("len(chainsV4) = %d, want 3", len(chainsV4))
|
t.Fatalf("len(chainsV4) = %d, want 6", len(chainsV4))
|
||||||
}
|
}
|
||||||
|
|
||||||
chainsV6, err := conn.ListChainsOfTableFamily(nftables.TableFamilyIPv6)
|
chainsV6, err := conn.ListChainsOfTableFamily(nftables.TableFamilyIPv6)
|
||||||
@ -470,8 +512,8 @@ func TestAddAndDelNetfilterChains(t *testing.T) {
|
|||||||
t.Fatalf("list chains failed: %v", err)
|
t.Fatalf("list chains failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(chainsV6) != 3 {
|
if len(chainsV6) != 6 {
|
||||||
t.Fatalf("len(chainsV6) = %d, want 3", len(chainsV6))
|
t.Fatalf("len(chainsV6) = %d, want 6", len(chainsV6))
|
||||||
}
|
}
|
||||||
|
|
||||||
runner.DelChains()
|
runner.DelChains()
|
||||||
@ -788,3 +830,87 @@ func TestNFTAddAndDelLoopbackRule(t *testing.T) {
|
|||||||
t.Fatalf("len(inputV4Rules) = %d, want 2", len(inputV4Rules))
|
t.Fatalf("len(inputV4Rules) = %d, want 2", len(inputV4Rules))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNFTAddAndDelHookRule(t *testing.T) {
|
||||||
|
if os.Geteuid() != 0 {
|
||||||
|
t.Skip(t.Name(), " requires privileges to create a namespace in order to run")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conn := newSysConn(t)
|
||||||
|
runner := newFakeNftablesRunner(t, conn)
|
||||||
|
runner.AddChains()
|
||||||
|
defer runner.DelChains()
|
||||||
|
runner.AddHooks()
|
||||||
|
|
||||||
|
forwardChain, err := getChainFromTable(conn, runner.nft4.Filter, "FORWARD")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to get forwardChain: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
forwardChainRules, err := conn.GetRules(forwardChain.Table, forwardChain)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to get rules: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(forwardChainRules) != 1 {
|
||||||
|
t.Fatalf("expected 1 rule in FORWARD chain, got %v", len(forwardChainRules))
|
||||||
|
}
|
||||||
|
|
||||||
|
inputChain, err := getChainFromTable(conn, runner.nft4.Filter, "INPUT")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to get inputChain: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
inputChainRules, err := conn.GetRules(inputChain.Table, inputChain)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to get rules: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(inputChainRules) != 1 {
|
||||||
|
t.Fatalf("expected 1 rule in INPUT chain, got %v", len(inputChainRules))
|
||||||
|
}
|
||||||
|
|
||||||
|
postroutingChain, err := getChainFromTable(conn, runner.nft4.Nat, "POSTROUTING")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to get postroutingChain: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
postroutingChainRules, err := conn.GetRules(postroutingChain.Table, postroutingChain)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to get rules: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(postroutingChainRules) != 1 {
|
||||||
|
t.Fatalf("expected 1 rule in POSTROUTING chain, got %v", len(postroutingChainRules))
|
||||||
|
}
|
||||||
|
|
||||||
|
runner.DelHooks(t.Logf)
|
||||||
|
|
||||||
|
forwardChainRules, err = conn.GetRules(forwardChain.Table, forwardChain)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to get rules: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(forwardChainRules) != 0 {
|
||||||
|
t.Fatalf("expected 0 rule in FORWARD chain, got %v", len(forwardChainRules))
|
||||||
|
}
|
||||||
|
|
||||||
|
inputChainRules, err = conn.GetRules(inputChain.Table, inputChain)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to get rules: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(inputChainRules) != 0 {
|
||||||
|
t.Fatalf("expected 0 rule in INPUT chain, got %v", len(inputChainRules))
|
||||||
|
}
|
||||||
|
|
||||||
|
postroutingChainRules, err = conn.GetRules(postroutingChain.Table, postroutingChain)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to get rules: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(postroutingChainRules) != 0 {
|
||||||
|
t.Fatalf("expected 0 rule in POSTROUTING chain, got %v", len(postroutingChainRules))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user