mirror of
https://github.com/tailscale/tailscale.git
synced 2025-04-20 05:31:40 +00:00
util/linuxfw: fix broken tests
These tests were broken at HEAD. CI currently does not run these as root, will figure out how to do that in a followup. Updates #5621 Updates #8555 Updates #8762 Signed-off-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
parent
a8fbe284b2
commit
b47cf04624
@ -474,6 +474,10 @@ func TestAddMatchSubnetRouteMarkRuleAccept(t *testing.T) {
|
|||||||
|
|
||||||
func newSysConn(t *testing.T) *nftables.Conn {
|
func newSysConn(t *testing.T) *nftables.Conn {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
if os.Geteuid() != 0 {
|
||||||
|
t.Skip(t.Name(), " requires privileges to create a namespace in order to run")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
runtime.LockOSThread()
|
runtime.LockOSThread()
|
||||||
|
|
||||||
@ -512,12 +516,21 @@ func newFakeNftablesRunner(t *testing.T, conn *nftables.Conn) *nftablesRunner {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAddAndDelNetfilterChains(t *testing.T) {
|
func checkChains(t *testing.T, conn *nftables.Conn, fam nftables.TableFamily, wantCount int) {
|
||||||
if os.Geteuid() != 0 {
|
t.Helper()
|
||||||
t.Skip(t.Name(), " requires privileges to create a namespace in order to run")
|
got, err := conn.ListChainsOfTableFamily(fam)
|
||||||
return
|
if err != nil {
|
||||||
|
t.Fatalf("conn.ListChainsOfTableFamily(%v) failed: %v", fam, err)
|
||||||
}
|
}
|
||||||
|
if len(got) != wantCount {
|
||||||
|
t.Fatalf("len(got) = %d, want %d", len(got), wantCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddAndDelNetfilterChains(t *testing.T) {
|
||||||
conn := newSysConn(t)
|
conn := newSysConn(t)
|
||||||
|
checkChains(t, conn, nftables.TableFamilyIPv4, 0)
|
||||||
|
checkChains(t, conn, nftables.TableFamilyIPv6, 0)
|
||||||
|
|
||||||
runner := newFakeNftablesRunner(t, conn)
|
runner := newFakeNftablesRunner(t, conn)
|
||||||
runner.AddChains()
|
runner.AddChains()
|
||||||
@ -531,33 +544,22 @@ func TestAddAndDelNetfilterChains(t *testing.T) {
|
|||||||
t.Fatalf("len(tables) = %d, want 4", len(tables))
|
t.Fatalf("len(tables) = %d, want 4", len(tables))
|
||||||
}
|
}
|
||||||
|
|
||||||
chainsV4, err := conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
|
checkChains(t, conn, nftables.TableFamilyIPv4, 6)
|
||||||
if err != nil {
|
checkChains(t, conn, nftables.TableFamilyIPv6, 6)
|
||||||
t.Fatalf("list chains failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(chainsV4) != 6 {
|
|
||||||
t.Fatalf("len(chainsV4) = %d, want 6", len(chainsV4))
|
|
||||||
}
|
|
||||||
|
|
||||||
chainsV6, err := conn.ListChainsOfTableFamily(nftables.TableFamilyIPv6)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("list chains failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(chainsV6) != 6 {
|
|
||||||
t.Fatalf("len(chainsV6) = %d, want 6", len(chainsV6))
|
|
||||||
}
|
|
||||||
|
|
||||||
runner.DelChains()
|
runner.DelChains()
|
||||||
|
|
||||||
|
// The default chains should still be present.
|
||||||
|
checkChains(t, conn, nftables.TableFamilyIPv4, 3)
|
||||||
|
checkChains(t, conn, nftables.TableFamilyIPv6, 3)
|
||||||
|
|
||||||
tables, err = conn.ListTables()
|
tables, err = conn.ListTables()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("conn.ListTables() failed: %v", err)
|
t.Fatalf("conn.ListTables() failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(tables) != 0 {
|
if len(tables) != 4 {
|
||||||
t.Fatalf("len(tables) = %d, want 0", len(tables))
|
t.Fatalf("len(tables) = %d, want 4", len(tables))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -646,12 +648,19 @@ func findCommonBaseRules(
|
|||||||
return get, nil
|
return get, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNFTAddAndDelNetfilterBase(t *testing.T) {
|
// checkChainRules verifies that the chain has the expected number of rules.
|
||||||
if os.Geteuid() != 0 {
|
func checkChainRules(t *testing.T, conn *nftables.Conn, chain *nftables.Chain, wantCount int) {
|
||||||
t.Skip(t.Name(), " requires privileges to create a namespace in order to run")
|
t.Helper()
|
||||||
return
|
got, err := conn.GetRules(chain.Table, chain)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("conn.GetRules() failed: %v", err)
|
||||||
}
|
}
|
||||||
|
if len(got) != wantCount {
|
||||||
|
t.Fatalf("got = %d, want %d", len(got), wantCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNFTAddAndDelNetfilterBase(t *testing.T) {
|
||||||
conn := newSysConn(t)
|
conn := newSysConn(t)
|
||||||
|
|
||||||
runner := newFakeNftablesRunner(t, conn)
|
runner := newFakeNftablesRunner(t, conn)
|
||||||
@ -664,30 +673,9 @@ func TestNFTAddAndDelNetfilterBase(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("getTsChains() failed: %v", err)
|
t.Fatalf("getTsChains() failed: %v", err)
|
||||||
}
|
}
|
||||||
|
checkChainRules(t, conn, inputV4, 3)
|
||||||
inputV4Rules, err := conn.GetRules(runner.nft4.Filter, inputV4)
|
checkChainRules(t, conn, forwardV4, 4)
|
||||||
if err != nil {
|
checkChainRules(t, conn, postroutingV4, 0)
|
||||||
t.Fatalf("conn.GetRules() failed: %v", err)
|
|
||||||
}
|
|
||||||
if len(inputV4Rules) != 2 {
|
|
||||||
t.Fatalf("len(inputV4Rules) = %d, want 2", len(inputV4Rules))
|
|
||||||
}
|
|
||||||
|
|
||||||
forwardV4Rules, err := conn.GetRules(runner.nft4.Filter, forwardV4)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("conn.GetRules() failed: %v", err)
|
|
||||||
}
|
|
||||||
if len(forwardV4Rules) != 4 {
|
|
||||||
t.Fatalf("len(forwardV4Rules) = %d, want 4", len(forwardV4Rules))
|
|
||||||
}
|
|
||||||
|
|
||||||
postroutingV4Rules, err := conn.GetRules(runner.nft4.Nat, postroutingV4)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("conn.GetRules() failed: %v", err)
|
|
||||||
}
|
|
||||||
if len(postroutingV4Rules) != 0 {
|
|
||||||
t.Fatalf("len(postroutingV4Rules) = %d, want 0", len(postroutingV4Rules))
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = findV4BaseRules(conn, inputV4, forwardV4, "testTunn")
|
_, err = findV4BaseRules(conn, inputV4, forwardV4, "testTunn")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -703,30 +691,9 @@ func TestNFTAddAndDelNetfilterBase(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("getTsChains() failed: %v", err)
|
t.Fatalf("getTsChains() failed: %v", err)
|
||||||
}
|
}
|
||||||
|
checkChainRules(t, conn, inputV6, 3)
|
||||||
inputV6Rules, err := conn.GetRules(runner.nft6.Filter, inputV6)
|
checkChainRules(t, conn, forwardV6, 4)
|
||||||
if err != nil {
|
checkChainRules(t, conn, postroutingV6, 0)
|
||||||
t.Fatalf("conn.GetRules() failed: %v", err)
|
|
||||||
}
|
|
||||||
if len(inputV6Rules) != 0 {
|
|
||||||
t.Fatalf("len(inputV6Rules) = %d, want 0", len(inputV4Rules))
|
|
||||||
}
|
|
||||||
|
|
||||||
forwardV6Rules, err := conn.GetRules(runner.nft6.Filter, forwardV6)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("conn.GetRules() failed: %v", err)
|
|
||||||
}
|
|
||||||
if len(forwardV6Rules) != 3 {
|
|
||||||
t.Fatalf("len(forwardV6Rules) = %d, want 3", len(forwardV4Rules))
|
|
||||||
}
|
|
||||||
|
|
||||||
postroutingV6Rules, err := conn.GetRules(runner.nft6.Nat, postroutingV6)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("conn.GetRules() failed: %v", err)
|
|
||||||
}
|
|
||||||
if len(postroutingV6Rules) != 0 {
|
|
||||||
t.Fatalf("len(postroutingV6Rules) = %d, want 0", len(postroutingV4Rules))
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = findCommonBaseRules(conn, forwardV6, "testTunn")
|
_, err = findCommonBaseRules(conn, forwardV6, "testTunn")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -740,13 +707,7 @@ func TestNFTAddAndDelNetfilterBase(t *testing.T) {
|
|||||||
t.Fatalf("conn.ListChains() failed: %v", err)
|
t.Fatalf("conn.ListChains() failed: %v", err)
|
||||||
}
|
}
|
||||||
for _, chain := range chains {
|
for _, chain := range chains {
|
||||||
chainRules, err := conn.GetRules(chain.Table, chain)
|
checkChainRules(t, conn, chain, 0)
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("conn.GetRules() failed: %v", err)
|
|
||||||
}
|
|
||||||
if len(chainRules) != 0 {
|
|
||||||
t.Fatalf("len(chainRules) = %d, want 0", len(chainRules))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -790,36 +751,36 @@ func findLoopBackRule(conn *nftables.Conn, proto nftables.TableFamily, table *nf
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestNFTAddAndDelLoopbackRule(t *testing.T) {
|
func TestNFTAddAndDelLoopbackRule(t *testing.T) {
|
||||||
if os.Geteuid() != 0 {
|
|
||||||
t.Skip(t.Name(), " requires privileges to create a namespace in order to run")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
conn := newSysConn(t)
|
conn := newSysConn(t)
|
||||||
|
|
||||||
runner := newFakeNftablesRunner(t, conn)
|
runner := newFakeNftablesRunner(t, conn)
|
||||||
runner.AddChains()
|
runner.AddChains()
|
||||||
defer runner.DelChains()
|
defer runner.DelChains()
|
||||||
runner.AddBase("testTunn")
|
|
||||||
defer runner.DelBase()
|
|
||||||
|
|
||||||
addr := netip.MustParseAddr("192.168.0.2")
|
|
||||||
addrV6 := netip.MustParseAddr("2001:db8::2")
|
|
||||||
runner.AddLoopbackRule(addr)
|
|
||||||
runner.AddLoopbackRule(addrV6)
|
|
||||||
|
|
||||||
inputV4, _, _, err := getTsChains(conn, nftables.TableFamilyIPv4)
|
inputV4, _, _, err := getTsChains(conn, nftables.TableFamilyIPv4)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("getTsChains() failed: %v", err)
|
t.Fatalf("getTsChains() failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
inputV4Rules, err := conn.GetRules(runner.nft4.Filter, inputV4)
|
inputV6, _, _, err := getTsChains(conn, nftables.TableFamilyIPv6)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("conn.GetRules() failed: %v", err)
|
t.Fatalf("getTsChains() failed: %v", err)
|
||||||
}
|
|
||||||
if len(inputV4Rules) != 3 {
|
|
||||||
t.Fatalf("len(inputV4Rules) = %d, want 3", len(inputV4Rules))
|
|
||||||
}
|
}
|
||||||
|
checkChainRules(t, conn, inputV4, 0)
|
||||||
|
checkChainRules(t, conn, inputV6, 0)
|
||||||
|
|
||||||
|
runner.AddBase("testTunn")
|
||||||
|
defer runner.DelBase()
|
||||||
|
checkChainRules(t, conn, inputV4, 3)
|
||||||
|
checkChainRules(t, conn, inputV6, 3)
|
||||||
|
|
||||||
|
addr := netip.MustParseAddr("192.168.0.2")
|
||||||
|
addrV6 := netip.MustParseAddr("2001:db8::2")
|
||||||
|
runner.AddLoopbackRule(addr)
|
||||||
|
runner.AddLoopbackRule(addrV6)
|
||||||
|
|
||||||
|
checkChainRules(t, conn, inputV4, 4)
|
||||||
|
checkChainRules(t, conn, inputV6, 4)
|
||||||
|
|
||||||
existingLoopBackRule, err := findLoopBackRule(conn, nftables.TableFamilyIPv4, runner.nft4.Filter, inputV4, addr)
|
existingLoopBackRule, err := findLoopBackRule(conn, nftables.TableFamilyIPv4, runner.nft4.Filter, inputV4, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -830,19 +791,6 @@ func TestNFTAddAndDelLoopbackRule(t *testing.T) {
|
|||||||
t.Fatalf("existingLoopBackRule.Handle = %d, want 0", existingLoopBackRule.Handle)
|
t.Fatalf("existingLoopBackRule.Handle = %d, want 0", existingLoopBackRule.Handle)
|
||||||
}
|
}
|
||||||
|
|
||||||
inputV6, _, _, err := getTsChains(conn, nftables.TableFamilyIPv6)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("getTsChains() failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
inputV6Rules, err := conn.GetRules(runner.nft6.Filter, inputV4)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("conn.GetRules() failed: %v", err)
|
|
||||||
}
|
|
||||||
if len(inputV6Rules) != 1 {
|
|
||||||
t.Fatalf("len(inputV4Rules) = %d, want 1", len(inputV4Rules))
|
|
||||||
}
|
|
||||||
|
|
||||||
existingLoopBackRuleV6, err := findLoopBackRule(conn, nftables.TableFamilyIPv6, runner.nft6.Filter, inputV6, addrV6)
|
existingLoopBackRuleV6, err := findLoopBackRule(conn, nftables.TableFamilyIPv6, runner.nft6.Filter, inputV6, addrV6)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("findLoopBackRule() failed: %v", err)
|
t.Fatalf("findLoopBackRule() failed: %v", err)
|
||||||
@ -855,21 +803,11 @@ func TestNFTAddAndDelLoopbackRule(t *testing.T) {
|
|||||||
runner.DelLoopbackRule(addr)
|
runner.DelLoopbackRule(addr)
|
||||||
runner.DelLoopbackRule(addrV6)
|
runner.DelLoopbackRule(addrV6)
|
||||||
|
|
||||||
inputV4Rules, err = conn.GetRules(runner.nft4.Filter, inputV4)
|
checkChainRules(t, conn, inputV4, 3)
|
||||||
if err != nil {
|
checkChainRules(t, conn, inputV6, 3)
|
||||||
t.Fatalf("conn.GetRules() failed: %v", err)
|
|
||||||
}
|
|
||||||
if len(inputV4Rules) != 2 {
|
|
||||||
t.Fatalf("len(inputV4Rules) = %d, want 2", len(inputV4Rules))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNFTAddAndDelHookRule(t *testing.T) {
|
func TestNFTAddAndDelHookRule(t *testing.T) {
|
||||||
if os.Geteuid() != 0 {
|
|
||||||
t.Skip(t.Name(), " requires privileges to create a namespace in order to run")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
conn := newSysConn(t)
|
conn := newSysConn(t)
|
||||||
runner := newFakeNftablesRunner(t, conn)
|
runner := newFakeNftablesRunner(t, conn)
|
||||||
runner.AddChains()
|
runner.AddChains()
|
||||||
@ -880,72 +818,24 @@ func TestNFTAddAndDelHookRule(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to get forwardChain: %v", err)
|
t.Fatalf("failed to get forwardChain: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
forwardChainRules, err := conn.GetRules(forwardChain.Table, forwardChain)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to get rules: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(forwardChainRules) != 1 {
|
|
||||||
t.Fatalf("expected 1 rule in FORWARD chain, got %v", len(forwardChainRules))
|
|
||||||
}
|
|
||||||
|
|
||||||
inputChain, err := getChainFromTable(conn, runner.nft4.Filter, "INPUT")
|
inputChain, err := getChainFromTable(conn, runner.nft4.Filter, "INPUT")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to get inputChain: %v", err)
|
t.Fatalf("failed to get inputChain: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
inputChainRules, err := conn.GetRules(inputChain.Table, inputChain)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to get rules: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(inputChainRules) != 1 {
|
|
||||||
t.Fatalf("expected 1 rule in INPUT chain, got %v", len(inputChainRules))
|
|
||||||
}
|
|
||||||
|
|
||||||
postroutingChain, err := getChainFromTable(conn, runner.nft4.Nat, "POSTROUTING")
|
postroutingChain, err := getChainFromTable(conn, runner.nft4.Nat, "POSTROUTING")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to get postroutingChain: %v", err)
|
t.Fatalf("failed to get postroutingChain: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
postroutingChainRules, err := conn.GetRules(postroutingChain.Table, postroutingChain)
|
checkChainRules(t, conn, forwardChain, 1)
|
||||||
if err != nil {
|
checkChainRules(t, conn, inputChain, 1)
|
||||||
t.Fatalf("failed to get rules: %v", err)
|
checkChainRules(t, conn, postroutingChain, 1)
|
||||||
}
|
|
||||||
|
|
||||||
if len(postroutingChainRules) != 1 {
|
|
||||||
t.Fatalf("expected 1 rule in POSTROUTING chain, got %v", len(postroutingChainRules))
|
|
||||||
}
|
|
||||||
|
|
||||||
runner.DelHooks(t.Logf)
|
runner.DelHooks(t.Logf)
|
||||||
|
|
||||||
forwardChainRules, err = conn.GetRules(forwardChain.Table, forwardChain)
|
checkChainRules(t, conn, forwardChain, 0)
|
||||||
if err != nil {
|
checkChainRules(t, conn, inputChain, 0)
|
||||||
t.Fatalf("failed to get rules: %v", err)
|
checkChainRules(t, conn, postroutingChain, 0)
|
||||||
}
|
|
||||||
|
|
||||||
if len(forwardChainRules) != 0 {
|
|
||||||
t.Fatalf("expected 0 rule in FORWARD chain, got %v", len(forwardChainRules))
|
|
||||||
}
|
|
||||||
|
|
||||||
inputChainRules, err = conn.GetRules(inputChain.Table, inputChain)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to get rules: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(inputChainRules) != 0 {
|
|
||||||
t.Fatalf("expected 0 rule in INPUT chain, got %v", len(inputChainRules))
|
|
||||||
}
|
|
||||||
|
|
||||||
postroutingChainRules, err = conn.GetRules(postroutingChain.Table, postroutingChain)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to get rules: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(postroutingChainRules) != 0 {
|
|
||||||
t.Fatalf("expected 0 rule in POSTROUTING chain, got %v", len(postroutingChainRules))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type testFWDetector struct {
|
type testFWDetector struct {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user