From aad35843190a75a0ec0a5355cbd62f72cdeb8983 Mon Sep 17 00:00:00 2001 From: Maisem Ali Date: Wed, 11 Oct 2023 17:11:56 +0000 Subject: [PATCH] util/linuxfw: move fake runner into pkg This allows using the fake runner in different packages that need to manage filter rules. Updates #cleanup Signed-off-by: Maisem Ali --- util/linuxfw/fake.go | 126 +++++++++++++++++++++++ util/linuxfw/iptables_runner_test.go | 146 +-------------------------- 2 files changed, 131 insertions(+), 141 deletions(-) create mode 100644 util/linuxfw/fake.go diff --git a/util/linuxfw/fake.go b/util/linuxfw/fake.go new file mode 100644 index 000000000..e76431d00 --- /dev/null +++ b/util/linuxfw/fake.go @@ -0,0 +1,126 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package linuxfw + +import ( + "errors" + "fmt" + "strings" +) + +type fakeIPTables struct { + n map[string][]string +} + +type fakeRule struct { + table, chain string + args []string +} + +func newFakeIPTables() *fakeIPTables { + return &fakeIPTables{ + n: map[string][]string{ + "filter/INPUT": nil, + "filter/OUTPUT": nil, + "filter/FORWARD": nil, + "nat/PREROUTING": nil, + "nat/OUTPUT": nil, + "nat/POSTROUTING": nil, + "mangle/FORWARD": nil, + }, + } +} + +func (n *fakeIPTables) Insert(table, chain string, pos int, args ...string) error { + k := table + "/" + chain + if rules, ok := n.n[k]; ok { + if pos > len(rules)+1 { + return fmt.Errorf("bad position %d in %s", pos, k) + } + rules = append(rules, "") + copy(rules[pos:], rules[pos-1:]) + rules[pos-1] = strings.Join(args, " ") + n.n[k] = rules + } else { + return fmt.Errorf("unknown table/chain %s", k) + } + return nil +} + +func (n *fakeIPTables) Append(table, chain string, args ...string) error { + k := table + "/" + chain + return n.Insert(table, chain, len(n.n[k])+1, args...) +} + +func (n *fakeIPTables) Exists(table, chain string, args ...string) (bool, error) { + k := table + "/" + chain + if rules, ok := n.n[k]; ok { + for _, rule := range rules { + if rule == strings.Join(args, " ") { + return true, nil + } + } + return false, nil + } else { + return false, fmt.Errorf("unknown table/chain %s", k) + } +} + +func (n *fakeIPTables) Delete(table, chain string, args ...string) error { + k := table + "/" + chain + if rules, ok := n.n[k]; ok { + for i, rule := range rules { + if rule == strings.Join(args, " ") { + rules = append(rules[:i], rules[i+1:]...) + n.n[k] = rules + return nil + } + } + return fmt.Errorf("delete of unknown rule %q from %s", strings.Join(args, " "), k) + } else { + return fmt.Errorf("unknown table/chain %s", k) + } +} + +func (n *fakeIPTables) ClearChain(table, chain string) error { + k := table + "/" + chain + if _, ok := n.n[k]; ok { + n.n[k] = nil + return nil + } else { + return errors.New("exitcode:1") + } +} + +func (n *fakeIPTables) NewChain(table, chain string) error { + k := table + "/" + chain + if _, ok := n.n[k]; ok { + return fmt.Errorf("table/chain %s already exists", k) + } + n.n[k] = nil + return nil +} + +func (n *fakeIPTables) DeleteChain(table, chain string) error { + k := table + "/" + chain + if rules, ok := n.n[k]; ok { + if len(rules) != 0 { + return fmt.Errorf("table/chain %s is not empty", k) + } + delete(n.n, k) + return nil + } else { + return fmt.Errorf("unknown table/chain %s", k) + } +} + +func NewFakeIPTablesRunner() *iptablesRunner { + ipt4 := newFakeIPTables() + ipt6 := newFakeIPTables() + + iptr := &iptablesRunner{ipt4, ipt6, true, true} + return iptr +} diff --git a/util/linuxfw/iptables_runner_test.go b/util/linuxfw/iptables_runner_test.go index d4c7c95f4..e678ddd6d 100644 --- a/util/linuxfw/iptables_runner_test.go +++ b/util/linuxfw/iptables_runner_test.go @@ -6,7 +6,6 @@ package linuxfw import ( - "errors" "net/netip" "strings" "testing" @@ -14,143 +13,8 @@ "tailscale.com/net/tsaddr" ) -var errExec = errors.New("execution failed") - -type fakeIPTables struct { - t *testing.T - n map[string][]string -} - -type fakeRule struct { - table, chain string - args []string -} - -func newIPTables(t *testing.T) *fakeIPTables { - return &fakeIPTables{ - t: t, - n: map[string][]string{ - "filter/INPUT": nil, - "filter/OUTPUT": nil, - "filter/FORWARD": nil, - "nat/PREROUTING": nil, - "nat/OUTPUT": nil, - "nat/POSTROUTING": nil, - }, - } -} - -func (n *fakeIPTables) Insert(table, chain string, pos int, args ...string) error { - k := table + "/" + chain - if rules, ok := n.n[k]; ok { - if pos > len(rules)+1 { - n.t.Errorf("bad position %d in %s", pos, k) - return errExec - } - rules = append(rules, "") - copy(rules[pos:], rules[pos-1:]) - rules[pos-1] = strings.Join(args, " ") - n.n[k] = rules - } else { - n.t.Errorf("unknown table/chain %s", k) - return errExec - } - return nil -} - -func (n *fakeIPTables) Append(table, chain string, args ...string) error { - k := table + "/" + chain - return n.Insert(table, chain, len(n.n[k])+1, args...) -} - -func (n *fakeIPTables) Exists(table, chain string, args ...string) (bool, error) { - k := table + "/" + chain - if rules, ok := n.n[k]; ok { - for _, rule := range rules { - if rule == strings.Join(args, " ") { - return true, nil - } - } - return false, nil - } else { - n.t.Logf("unknown table/chain %s", k) - return false, errExec - } -} - -func hasChain(n *fakeIPTables, table, chain string) bool { - k := table + "/" + chain - if _, ok := n.n[k]; ok { - return true - } else { - return false - } -} - -func (n *fakeIPTables) Delete(table, chain string, args ...string) error { - k := table + "/" + chain - if rules, ok := n.n[k]; ok { - for i, rule := range rules { - if rule == strings.Join(args, " ") { - rules = append(rules[:i], rules[i+1:]...) - n.n[k] = rules - return nil - } - } - n.t.Errorf("delete of unknown rule %q from %s", strings.Join(args, " "), k) - return errExec - } else { - n.t.Errorf("unknown table/chain %s", k) - return errExec - } -} - -func (n *fakeIPTables) ClearChain(table, chain string) error { - k := table + "/" + chain - if _, ok := n.n[k]; ok { - n.n[k] = nil - return nil - } else { - n.t.Logf("note: ClearChain: unknown table/chain %s", k) - return errors.New("exitcode:1") - } -} - -func (n *fakeIPTables) NewChain(table, chain string) error { - k := table + "/" + chain - if _, ok := n.n[k]; ok { - n.t.Errorf("table/chain %s already exists", k) - return errExec - } - n.n[k] = nil - return nil -} - -func (n *fakeIPTables) DeleteChain(table, chain string) error { - k := table + "/" + chain - if rules, ok := n.n[k]; ok { - if len(rules) != 0 { - n.t.Errorf("%s is not empty", k) - return errExec - } - delete(n.n, k) - return nil - } else { - n.t.Errorf("%s does not exist", k) - return errExec - } -} - -func newFakeIPTablesRunner(t *testing.T) *iptablesRunner { - ipt4 := newIPTables(t) - ipt6 := newIPTables(t) - - iptr := &iptablesRunner{ipt4, ipt6, true, true} - return iptr -} - func TestAddAndDeleteChains(t *testing.T) { - iptr := newFakeIPTablesRunner(t) + iptr := NewFakeIPTablesRunner() err := iptr.AddChains() if err != nil { t.Fatal(err) @@ -189,7 +53,7 @@ func TestAddAndDeleteChains(t *testing.T) { } func TestAddAndDeleteHooks(t *testing.T) { - iptr := newFakeIPTablesRunner(t) + iptr := NewFakeIPTablesRunner() // don't need to test what happens if the chains don't exist, because // this is handled by fake iptables, in realife iptables would return error. if err := iptr.AddChains(); err != nil { @@ -243,7 +107,7 @@ func TestAddAndDeleteHooks(t *testing.T) { } func TestAddAndDeleteBase(t *testing.T) { - iptr := newFakeIPTablesRunner(t) + iptr := NewFakeIPTablesRunner() tunname := "tun0" if err := iptr.AddChains(); err != nil { t.Fatal(err) @@ -306,7 +170,7 @@ func TestAddAndDeleteBase(t *testing.T) { } func TestAddAndDelLoopbackRule(t *testing.T) { - iptr := newFakeIPTablesRunner(t) + iptr := NewFakeIPTablesRunner() // We don't need to test for malformed addresses, AddLoopbackRule // takes in a netip.Addr, which is already valid. fakeAddrV4 := netip.MustParseAddr("192.168.0.2") @@ -377,7 +241,7 @@ func TestAddAndDelLoopbackRule(t *testing.T) { } func TestAddAndDelSNATRule(t *testing.T) { - iptr := newFakeIPTablesRunner(t) + iptr := NewFakeIPTablesRunner() if err := iptr.AddChains(); err != nil { t.Fatal(err)