util/linuxfw: move fake runner into pkg

This allows using the fake runner in different packages
that need to manage filter rules.

Updates #cleanup

Signed-off-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
Maisem Ali 2023-10-11 17:11:56 +00:00 committed by Maisem Ali
parent fffafc65d6
commit aad3584319
2 changed files with 131 additions and 141 deletions

126
util/linuxfw/fake.go Normal file
View File

@ -0,0 +1,126 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build linux
package linuxfw
import (
"errors"
"fmt"
"strings"
)
type fakeIPTables struct {
n map[string][]string
}
type fakeRule struct {
table, chain string
args []string
}
func newFakeIPTables() *fakeIPTables {
return &fakeIPTables{
n: map[string][]string{
"filter/INPUT": nil,
"filter/OUTPUT": nil,
"filter/FORWARD": nil,
"nat/PREROUTING": nil,
"nat/OUTPUT": nil,
"nat/POSTROUTING": nil,
"mangle/FORWARD": nil,
},
}
}
func (n *fakeIPTables) Insert(table, chain string, pos int, args ...string) error {
k := table + "/" + chain
if rules, ok := n.n[k]; ok {
if pos > len(rules)+1 {
return fmt.Errorf("bad position %d in %s", pos, k)
}
rules = append(rules, "")
copy(rules[pos:], rules[pos-1:])
rules[pos-1] = strings.Join(args, " ")
n.n[k] = rules
} else {
return fmt.Errorf("unknown table/chain %s", k)
}
return nil
}
func (n *fakeIPTables) Append(table, chain string, args ...string) error {
k := table + "/" + chain
return n.Insert(table, chain, len(n.n[k])+1, args...)
}
func (n *fakeIPTables) Exists(table, chain string, args ...string) (bool, error) {
k := table + "/" + chain
if rules, ok := n.n[k]; ok {
for _, rule := range rules {
if rule == strings.Join(args, " ") {
return true, nil
}
}
return false, nil
} else {
return false, fmt.Errorf("unknown table/chain %s", k)
}
}
func (n *fakeIPTables) Delete(table, chain string, args ...string) error {
k := table + "/" + chain
if rules, ok := n.n[k]; ok {
for i, rule := range rules {
if rule == strings.Join(args, " ") {
rules = append(rules[:i], rules[i+1:]...)
n.n[k] = rules
return nil
}
}
return fmt.Errorf("delete of unknown rule %q from %s", strings.Join(args, " "), k)
} else {
return fmt.Errorf("unknown table/chain %s", k)
}
}
func (n *fakeIPTables) ClearChain(table, chain string) error {
k := table + "/" + chain
if _, ok := n.n[k]; ok {
n.n[k] = nil
return nil
} else {
return errors.New("exitcode:1")
}
}
func (n *fakeIPTables) NewChain(table, chain string) error {
k := table + "/" + chain
if _, ok := n.n[k]; ok {
return fmt.Errorf("table/chain %s already exists", k)
}
n.n[k] = nil
return nil
}
func (n *fakeIPTables) DeleteChain(table, chain string) error {
k := table + "/" + chain
if rules, ok := n.n[k]; ok {
if len(rules) != 0 {
return fmt.Errorf("table/chain %s is not empty", k)
}
delete(n.n, k)
return nil
} else {
return fmt.Errorf("unknown table/chain %s", k)
}
}
func NewFakeIPTablesRunner() *iptablesRunner {
ipt4 := newFakeIPTables()
ipt6 := newFakeIPTables()
iptr := &iptablesRunner{ipt4, ipt6, true, true}
return iptr
}

View File

@ -6,7 +6,6 @@
package linuxfw package linuxfw
import ( import (
"errors"
"net/netip" "net/netip"
"strings" "strings"
"testing" "testing"
@ -14,143 +13,8 @@
"tailscale.com/net/tsaddr" "tailscale.com/net/tsaddr"
) )
var errExec = errors.New("execution failed")
type fakeIPTables struct {
t *testing.T
n map[string][]string
}
type fakeRule struct {
table, chain string
args []string
}
func newIPTables(t *testing.T) *fakeIPTables {
return &fakeIPTables{
t: t,
n: map[string][]string{
"filter/INPUT": nil,
"filter/OUTPUT": nil,
"filter/FORWARD": nil,
"nat/PREROUTING": nil,
"nat/OUTPUT": nil,
"nat/POSTROUTING": nil,
},
}
}
func (n *fakeIPTables) Insert(table, chain string, pos int, args ...string) error {
k := table + "/" + chain
if rules, ok := n.n[k]; ok {
if pos > len(rules)+1 {
n.t.Errorf("bad position %d in %s", pos, k)
return errExec
}
rules = append(rules, "")
copy(rules[pos:], rules[pos-1:])
rules[pos-1] = strings.Join(args, " ")
n.n[k] = rules
} else {
n.t.Errorf("unknown table/chain %s", k)
return errExec
}
return nil
}
func (n *fakeIPTables) Append(table, chain string, args ...string) error {
k := table + "/" + chain
return n.Insert(table, chain, len(n.n[k])+1, args...)
}
func (n *fakeIPTables) Exists(table, chain string, args ...string) (bool, error) {
k := table + "/" + chain
if rules, ok := n.n[k]; ok {
for _, rule := range rules {
if rule == strings.Join(args, " ") {
return true, nil
}
}
return false, nil
} else {
n.t.Logf("unknown table/chain %s", k)
return false, errExec
}
}
func hasChain(n *fakeIPTables, table, chain string) bool {
k := table + "/" + chain
if _, ok := n.n[k]; ok {
return true
} else {
return false
}
}
func (n *fakeIPTables) Delete(table, chain string, args ...string) error {
k := table + "/" + chain
if rules, ok := n.n[k]; ok {
for i, rule := range rules {
if rule == strings.Join(args, " ") {
rules = append(rules[:i], rules[i+1:]...)
n.n[k] = rules
return nil
}
}
n.t.Errorf("delete of unknown rule %q from %s", strings.Join(args, " "), k)
return errExec
} else {
n.t.Errorf("unknown table/chain %s", k)
return errExec
}
}
func (n *fakeIPTables) ClearChain(table, chain string) error {
k := table + "/" + chain
if _, ok := n.n[k]; ok {
n.n[k] = nil
return nil
} else {
n.t.Logf("note: ClearChain: unknown table/chain %s", k)
return errors.New("exitcode:1")
}
}
func (n *fakeIPTables) NewChain(table, chain string) error {
k := table + "/" + chain
if _, ok := n.n[k]; ok {
n.t.Errorf("table/chain %s already exists", k)
return errExec
}
n.n[k] = nil
return nil
}
func (n *fakeIPTables) DeleteChain(table, chain string) error {
k := table + "/" + chain
if rules, ok := n.n[k]; ok {
if len(rules) != 0 {
n.t.Errorf("%s is not empty", k)
return errExec
}
delete(n.n, k)
return nil
} else {
n.t.Errorf("%s does not exist", k)
return errExec
}
}
func newFakeIPTablesRunner(t *testing.T) *iptablesRunner {
ipt4 := newIPTables(t)
ipt6 := newIPTables(t)
iptr := &iptablesRunner{ipt4, ipt6, true, true}
return iptr
}
func TestAddAndDeleteChains(t *testing.T) { func TestAddAndDeleteChains(t *testing.T) {
iptr := newFakeIPTablesRunner(t) iptr := NewFakeIPTablesRunner()
err := iptr.AddChains() err := iptr.AddChains()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -189,7 +53,7 @@ func TestAddAndDeleteChains(t *testing.T) {
} }
func TestAddAndDeleteHooks(t *testing.T) { func TestAddAndDeleteHooks(t *testing.T) {
iptr := newFakeIPTablesRunner(t) iptr := NewFakeIPTablesRunner()
// don't need to test what happens if the chains don't exist, because // don't need to test what happens if the chains don't exist, because
// this is handled by fake iptables, in realife iptables would return error. // this is handled by fake iptables, in realife iptables would return error.
if err := iptr.AddChains(); err != nil { if err := iptr.AddChains(); err != nil {
@ -243,7 +107,7 @@ func TestAddAndDeleteHooks(t *testing.T) {
} }
func TestAddAndDeleteBase(t *testing.T) { func TestAddAndDeleteBase(t *testing.T) {
iptr := newFakeIPTablesRunner(t) iptr := NewFakeIPTablesRunner()
tunname := "tun0" tunname := "tun0"
if err := iptr.AddChains(); err != nil { if err := iptr.AddChains(); err != nil {
t.Fatal(err) t.Fatal(err)
@ -306,7 +170,7 @@ func TestAddAndDeleteBase(t *testing.T) {
} }
func TestAddAndDelLoopbackRule(t *testing.T) { func TestAddAndDelLoopbackRule(t *testing.T) {
iptr := newFakeIPTablesRunner(t) iptr := NewFakeIPTablesRunner()
// We don't need to test for malformed addresses, AddLoopbackRule // We don't need to test for malformed addresses, AddLoopbackRule
// takes in a netip.Addr, which is already valid. // takes in a netip.Addr, which is already valid.
fakeAddrV4 := netip.MustParseAddr("192.168.0.2") fakeAddrV4 := netip.MustParseAddr("192.168.0.2")
@ -377,7 +241,7 @@ func TestAddAndDelLoopbackRule(t *testing.T) {
} }
func TestAddAndDelSNATRule(t *testing.T) { func TestAddAndDelSNATRule(t *testing.T) {
iptr := newFakeIPTablesRunner(t) iptr := NewFakeIPTablesRunner()
if err := iptr.AddChains(); err != nil { if err := iptr.AddChains(); err != nil {
t.Fatal(err) t.Fatal(err)