util/linuxfw: fix IPv6 availability check for nftables (#12009)

* util/linuxfw: fix IPv6 NAT availability check for nftables

When running firewall in nftables mode,
there is no need for a separate NAT availability check
(unlike with iptables, there are no hosts that support nftables, but not IPv6 NAT - see tailscale/tailscale#11353).
This change fixes a firewall NAT availability check that was using the no-longer set ipv6NATAvailable field
by removing the field and using a method that, for nftables, just checks that IPv6 is available.

Updates tailscale/tailscale#12008

Signed-off-by: Irbe Krumina <irbe@tailscale.com>
This commit is contained in:
Irbe Krumina
2024-05-14 08:51:53 +01:00
committed by GitHub
parent 8aa5c3534d
commit 7ef2f72135
4 changed files with 106 additions and 77 deletions

View File

@@ -20,6 +20,8 @@ import (
"github.com/mdlayher/netlink"
"github.com/vishvananda/netns"
"tailscale.com/net/tsaddr"
"tailscale.com/tstest"
"tailscale.com/types/logger"
)
// nfdump returns a hexdump of 4 bytes per line (like nft --debug=all), allowing
@@ -503,19 +505,6 @@ func cleanupSysConn(t *testing.T, ns netns.NsHandle) {
}
}
func newFakeNftablesRunner(t *testing.T, conn *nftables.Conn) *nftablesRunner {
nft4 := &nftable{Proto: nftables.TableFamilyIPv4}
nft6 := &nftable{Proto: nftables.TableFamilyIPv6}
return &nftablesRunner{
conn: conn,
nft4: nft4,
nft6: nft6,
v6Available: true,
v6NATAvailable: true,
}
}
func checkChains(t *testing.T, conn *nftables.Conn, fam nftables.TableFamily, wantCount int) {
t.Helper()
got, err := conn.ListChainsOfTableFamily(fam)
@@ -526,42 +515,76 @@ func checkChains(t *testing.T, conn *nftables.Conn, fam nftables.TableFamily, wa
t.Fatalf("len(got) = %d, want %d", len(got), wantCount)
}
}
func checkTables(t *testing.T, conn *nftables.Conn, fam nftables.TableFamily, wantCount int) {
t.Helper()
got, err := conn.ListTablesOfFamily(fam)
if err != nil {
t.Fatalf("conn.ListTablesOfFamily(%v) failed: %v", fam, err)
}
if len(got) != wantCount {
t.Fatalf("len(got) = %d, want %d", len(got), wantCount)
}
}
func TestAddAndDelNetfilterChains(t *testing.T) {
conn := newSysConn(t)
checkChains(t, conn, nftables.TableFamilyIPv4, 0)
checkChains(t, conn, nftables.TableFamilyIPv6, 0)
runner := newFakeNftablesRunner(t, conn)
if err := runner.AddChains(); err != nil {
t.Fatalf("runner.AddChains() failed: %v", err)
type test struct {
hostHasIPv6 bool
initIPv4ChainCount int
initIPv6ChainCount int
ipv4TableCount int
ipv6TableCount int
ipv4ChainCount int
ipv6ChainCount int
ipv4ChainCountPostDelete int
ipv6ChainCountPostDelete int
}
tests := []test{
{
hostHasIPv6: true,
initIPv4ChainCount: 0,
initIPv6ChainCount: 0,
ipv4TableCount: 2,
ipv6TableCount: 2,
ipv4ChainCount: 6,
ipv6ChainCount: 6,
ipv4ChainCountPostDelete: 3,
ipv6ChainCountPostDelete: 3,
},
{ // host without IPv6 support
ipv4TableCount: 2,
ipv4ChainCount: 6,
ipv4ChainCountPostDelete: 3,
}}
for _, tt := range tests {
t.Logf("running a test case for IPv6 support: %v", tt.hostHasIPv6)
conn := newSysConn(t)
runner := newFakeNftablesRunnerWithConn(t, conn, tt.hostHasIPv6)
tables, err := conn.ListTables()
if err != nil {
t.Fatalf("conn.ListTables() failed: %v", err)
}
// Check that we start off with no chains.
checkChains(t, conn, nftables.TableFamilyIPv4, tt.initIPv4ChainCount)
checkChains(t, conn, nftables.TableFamilyIPv6, tt.initIPv6ChainCount)
if len(tables) != 4 {
t.Fatalf("len(tables) = %d, want 4", len(tables))
}
if err := runner.AddChains(); err != nil {
t.Fatalf("runner.AddChains() failed: %v", err)
}
checkChains(t, conn, nftables.TableFamilyIPv4, 6)
checkChains(t, conn, nftables.TableFamilyIPv6, 6)
// Check that the amount of tables for each IP family is as expected.
checkTables(t, conn, nftables.TableFamilyIPv4, tt.ipv4TableCount)
checkTables(t, conn, nftables.TableFamilyIPv6, tt.ipv6TableCount)
runner.DelChains()
// Check that the amount of chains for each IP family is as expected.
checkChains(t, conn, nftables.TableFamilyIPv4, tt.ipv4ChainCount)
checkChains(t, conn, nftables.TableFamilyIPv6, tt.ipv6ChainCount)
// The default chains should still be present.
checkChains(t, conn, nftables.TableFamilyIPv4, 3)
checkChains(t, conn, nftables.TableFamilyIPv6, 3)
if err := runner.DelChains(); err != nil {
t.Fatalf("runner.DelChains() failed: %v", err)
}
tables, err = conn.ListTables()
if err != nil {
t.Fatalf("conn.ListTables() failed: %v", err)
}
if len(tables) != 4 {
t.Fatalf("len(tables) = %d, want 4", len(tables))
// Test that the tables as well as the default chains are still present.
checkChains(t, conn, nftables.TableFamilyIPv4, tt.ipv4ChainCountPostDelete)
checkChains(t, conn, nftables.TableFamilyIPv6, tt.ipv6ChainCountPostDelete)
checkTables(t, conn, nftables.TableFamilyIPv4, tt.ipv4TableCount)
checkTables(t, conn, nftables.TableFamilyIPv6, tt.ipv6TableCount)
}
}
@@ -665,7 +688,8 @@ func checkChainRules(t *testing.T, conn *nftables.Conn, chain *nftables.Chain, w
func TestNFTAddAndDelNetfilterBase(t *testing.T) {
conn := newSysConn(t)
runner := newFakeNftablesRunner(t, conn)
runner := newFakeNftablesRunnerWithConn(t, conn, true)
if err := runner.AddChains(); err != nil {
t.Fatalf("AddChains() failed: %v", err)
}
@@ -759,7 +783,7 @@ func findLoopBackRule(conn *nftables.Conn, proto nftables.TableFamily, table *nf
func TestNFTAddAndDelLoopbackRule(t *testing.T) {
conn := newSysConn(t)
runner := newFakeNftablesRunner(t, conn)
runner := newFakeNftablesRunnerWithConn(t, conn, true)
if err := runner.AddChains(); err != nil {
t.Fatalf("AddChains() failed: %v", err)
}
@@ -817,7 +841,7 @@ func TestNFTAddAndDelLoopbackRule(t *testing.T) {
func TestNFTAddAndDelHookRule(t *testing.T) {
conn := newSysConn(t)
runner := newFakeNftablesRunner(t, conn)
runner := newFakeNftablesRunnerWithConn(t, conn, true)
if err := runner.AddChains(); err != nil {
t.Fatalf("AddChains() failed: %v", err)
}
@@ -868,11 +892,11 @@ func (t *testFWDetector) nftDetect() (int, error) {
// postrouting chains are cleaned up.
func TestCreateDummyPostroutingChains(t *testing.T) {
conn := newSysConn(t)
runner := newFakeNftablesRunner(t, conn)
runner := newFakeNftablesRunnerWithConn(t, conn, true)
if err := runner.createDummyPostroutingChains(); err != nil {
t.Fatalf("createDummyPostroutingChains() failed: %v", err)
}
for _, table := range runner.getNATTables() {
for _, table := range runner.getTables() {
nt, err := getTableIfExists(conn, table.Proto, tsDummyTableName)
if err != nil {
t.Fatalf("getTableIfExists() failed: %v", err)
@@ -929,3 +953,14 @@ func TestPickFirewallModeFromInstalledRules(t *testing.T) {
})
}
}
func newFakeNftablesRunnerWithConn(t *testing.T, conn *nftables.Conn, hasIPv6 bool) *nftablesRunner {
t.Helper()
if !hasIPv6 {
tstest.Replace(t, &checkIPv6ForTest, func(logger.Logf) error {
return errors.New("test: no IPv6")
})
}
return newNfTablesRunnerWithConn(t.Logf, conn)
}