mirror of
https://github.com/tailscale/tailscale.git
synced 2025-08-11 21:27:31 +00:00
util/linuxfw: decoupling IPTables logic from linux router
This change is introducing new netfilterRunner interface and moving iptables manipulation to a lower leveled iptables runner. For #391 Signed-off-by: KevinLiang10 <kevinliang@tailscale.com>
This commit is contained in:

committed by
KevinLiang10

parent
9c64e015e5
commit
243ce6ccc1
@@ -22,8 +22,10 @@ import (
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/exp/slices"
|
||||
"tailscale.com/net/netmon"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tstest"
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/util/linuxfw"
|
||||
)
|
||||
|
||||
func TestRouterStates(t *testing.T) {
|
||||
@@ -328,7 +330,7 @@ ip route add throw 192.168.0.0/24 table 52` + basic,
|
||||
defer mon.Close()
|
||||
|
||||
fake := NewFakeOS(t)
|
||||
router, err := newUserspaceRouterAdvanced(t.Logf, "tailscale0", mon, fake.netfilter4, fake.netfilter6, fake, true, true)
|
||||
router, err := newUserspaceRouterAdvanced(t.Logf, "tailscale0", mon, fake.nfr, fake)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create router: %v", err)
|
||||
}
|
||||
@@ -362,15 +364,25 @@ ip route add throw 192.168.0.0/24 table 52` + basic,
|
||||
}
|
||||
}
|
||||
|
||||
type fakeNetfilter struct {
|
||||
t *testing.T
|
||||
n map[string][]string
|
||||
type fakeIPTablesRunner struct {
|
||||
t *testing.T
|
||||
ipt4 map[string][]string
|
||||
ipt6 map[string][]string
|
||||
//we always assume ipv6 and ipv6 nat are enabled when testing
|
||||
}
|
||||
|
||||
func newNetfilter(t *testing.T) *fakeNetfilter {
|
||||
return &fakeNetfilter{
|
||||
func newIPTablesRunner(t *testing.T) netfilterRunner {
|
||||
return &fakeIPTablesRunner{
|
||||
t: t,
|
||||
n: map[string][]string{
|
||||
ipt4: map[string][]string{
|
||||
"filter/INPUT": nil,
|
||||
"filter/OUTPUT": nil,
|
||||
"filter/FORWARD": nil,
|
||||
"nat/PREROUTING": nil,
|
||||
"nat/OUTPUT": nil,
|
||||
"nat/POSTROUTING": nil,
|
||||
},
|
||||
ipt6: map[string][]string{
|
||||
"filter/INPUT": nil,
|
||||
"filter/OUTPUT": nil,
|
||||
"filter/FORWARD": nil,
|
||||
@@ -381,115 +393,222 @@ func newNetfilter(t *testing.T) *fakeNetfilter {
|
||||
}
|
||||
}
|
||||
|
||||
func (n *fakeNetfilter) Insert(table, chain string, pos int, args ...string) error {
|
||||
k := table + "/" + chain
|
||||
if rules, ok := n.n[k]; ok {
|
||||
if pos > len(rules)+1 {
|
||||
n.t.Errorf("bad position %d in %s", pos, k)
|
||||
return errExec
|
||||
func insertRule(n *fakeIPTablesRunner, curIPT map[string][]string, chain, newRule string) error {
|
||||
// Get current rules for filter/ts-input chain with according IP version
|
||||
curTSInputRules, ok := curIPT[chain]
|
||||
if !ok {
|
||||
n.t.Fatalf("no %s chain exists", chain)
|
||||
return fmt.Errorf("no %s chain exists", chain)
|
||||
}
|
||||
|
||||
// Add new rule to top of filter/ts-input
|
||||
curTSInputRules = append(curTSInputRules, "")
|
||||
copy(curTSInputRules[1:], curTSInputRules)
|
||||
curTSInputRules[0] = newRule
|
||||
curIPT[chain] = curTSInputRules
|
||||
return nil
|
||||
}
|
||||
|
||||
func appendRule(n *fakeIPTablesRunner, curIPT map[string][]string, chain, newRule string) error {
|
||||
// Get current rules for filter/ts-input chain with according IP version
|
||||
curTSInputRules, ok := curIPT[chain]
|
||||
if !ok {
|
||||
n.t.Fatalf("no %s chain exists", chain)
|
||||
return fmt.Errorf("no %s chain exists", chain)
|
||||
}
|
||||
|
||||
// Add new rule to end of filter/ts-input
|
||||
curTSInputRules = append(curTSInputRules, newRule)
|
||||
curIPT[chain] = curTSInputRules
|
||||
return nil
|
||||
}
|
||||
|
||||
func deleteRule(n *fakeIPTablesRunner, curIPT map[string][]string, chain, delRule string) error {
|
||||
// Get current rules for filter/ts-input chain with according IP version
|
||||
curTSInputRules, ok := curIPT[chain]
|
||||
if !ok {
|
||||
n.t.Fatalf("no %s chain exists", chain)
|
||||
return fmt.Errorf("no %s chain exists", chain)
|
||||
}
|
||||
|
||||
// Remove rule from filter/ts-input
|
||||
for i, rule := range curTSInputRules {
|
||||
if rule == delRule {
|
||||
curTSInputRules = append(curTSInputRules[:i], curTSInputRules[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
curIPT[chain] = curTSInputRules
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *fakeIPTablesRunner) AddLoopbackRule(addr netip.Addr) error {
|
||||
curIPT := n.ipt4
|
||||
if addr.Is6() {
|
||||
curIPT = n.ipt6
|
||||
}
|
||||
newRule := fmt.Sprintf("-i lo -s %s -j ACCEPT", addr.String())
|
||||
|
||||
return insertRule(n, curIPT, "filter/ts-input", newRule)
|
||||
}
|
||||
|
||||
func (n *fakeIPTablesRunner) AddBase(tunname string) error {
|
||||
if err := n.AddBase4(tunname); err != nil {
|
||||
return err
|
||||
}
|
||||
if n.HasIPV6() {
|
||||
if err := n.AddBase6(tunname); err != nil {
|
||||
return err
|
||||
}
|
||||
rules = append(rules, "")
|
||||
copy(rules[pos:], rules[pos-1:])
|
||||
rules[pos-1] = strings.Join(args, " ")
|
||||
n.n[k] = rules
|
||||
} else {
|
||||
n.t.Errorf("unknown table/chain %s", k)
|
||||
return errExec
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *fakeNetfilter) Append(table, chain string, args ...string) error {
|
||||
k := table + "/" + chain
|
||||
return n.Insert(table, chain, len(n.n[k])+1, args...)
|
||||
}
|
||||
|
||||
func (n *fakeNetfilter) Exists(table, chain string, args ...string) (bool, error) {
|
||||
k := table + "/" + chain
|
||||
if rules, ok := n.n[k]; ok {
|
||||
for _, rule := range rules {
|
||||
if rule == strings.Join(args, " ") {
|
||||
return true, nil
|
||||
}
|
||||
func (n *fakeIPTablesRunner) AddBase4(tunname string) error {
|
||||
curIPT := n.ipt4
|
||||
newRules := []struct{ chain, rule string }{
|
||||
{"filter/ts-input", fmt.Sprintf("! -i %s -s %s -j RETURN", tunname, tsaddr.ChromeOSVMRange().String())},
|
||||
{"filter/ts-input", fmt.Sprintf("! -i %s -s %s -j DROP", tunname, tsaddr.CGNATRange().String())},
|
||||
{"filter/ts-forward", fmt.Sprintf("-i %s -j MARK --set-mark %s/%s", tunname, linuxfw.TailscaleSubnetRouteMark, linuxfw.TailscaleFwmarkMask)},
|
||||
{"filter/ts-forward", fmt.Sprintf("-m mark --mark %s/%s -j ACCEPT", linuxfw.TailscaleSubnetRouteMark, linuxfw.TailscaleFwmarkMask)},
|
||||
{"filter/ts-forward", fmt.Sprintf("-o %s -s %s -j DROP", tunname, tsaddr.CGNATRange().String())},
|
||||
{"filter/ts-forward", fmt.Sprintf("-o %s -j ACCEPT", tunname)},
|
||||
}
|
||||
for _, rule := range newRules {
|
||||
if err := appendRule(n, curIPT, rule.chain, rule.rule); err != nil {
|
||||
return fmt.Errorf("add rule %q to chain %q: %w", rule.rule, rule.chain, err)
|
||||
}
|
||||
return false, nil
|
||||
} else {
|
||||
n.t.Errorf("unknown table/chain %s", k)
|
||||
return false, errExec
|
||||
}
|
||||
}
|
||||
|
||||
func (n *fakeNetfilter) Delete(table, chain string, args ...string) error {
|
||||
k := table + "/" + chain
|
||||
if rules, ok := n.n[k]; ok {
|
||||
for i, rule := range rules {
|
||||
if rule == strings.Join(args, " ") {
|
||||
rules = append(rules[:i], rules[i+1:]...)
|
||||
n.n[k] = rules
|
||||
return nil
|
||||
}
|
||||
}
|
||||
n.t.Errorf("delete of unknown rule %q from %s", strings.Join(args, " "), k)
|
||||
return errExec
|
||||
} else {
|
||||
n.t.Errorf("unknown table/chain %s", k)
|
||||
return errExec
|
||||
}
|
||||
}
|
||||
|
||||
func (n *fakeNetfilter) ClearChain(table, chain string) error {
|
||||
k := table + "/" + chain
|
||||
if _, ok := n.n[k]; ok {
|
||||
n.n[k] = nil
|
||||
return nil
|
||||
} else {
|
||||
n.t.Logf("note: ClearChain: unknown table/chain %s", k)
|
||||
return errors.New("exitcode:1")
|
||||
}
|
||||
}
|
||||
|
||||
func (n *fakeNetfilter) NewChain(table, chain string) error {
|
||||
k := table + "/" + chain
|
||||
if _, ok := n.n[k]; ok {
|
||||
n.t.Errorf("table/chain %s already exists", k)
|
||||
return errExec
|
||||
}
|
||||
n.n[k] = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *fakeNetfilter) DeleteChain(table, chain string) error {
|
||||
k := table + "/" + chain
|
||||
if rules, ok := n.n[k]; ok {
|
||||
if len(rules) != 0 {
|
||||
n.t.Errorf("%s is not empty", k)
|
||||
return errExec
|
||||
}
|
||||
delete(n.n, k)
|
||||
return nil
|
||||
} else {
|
||||
n.t.Errorf("%s does not exist", k)
|
||||
return errExec
|
||||
func (n *fakeIPTablesRunner) AddBase6(tunname string) error {
|
||||
curIPT := n.ipt6
|
||||
newRules := []struct{ chain, rule string }{
|
||||
{"filter/ts-forward", fmt.Sprintf("-i %s -j MARK --set-mark %s/%s", tunname, linuxfw.TailscaleSubnetRouteMark, linuxfw.TailscaleFwmarkMask)},
|
||||
{"filter/ts-forward", fmt.Sprintf("-m mark --mark %s/%s -j ACCEPT", linuxfw.TailscaleSubnetRouteMark, linuxfw.TailscaleFwmarkMask)},
|
||||
{"filter/ts-forward", fmt.Sprintf("-o %s -j ACCEPT", tunname)},
|
||||
}
|
||||
for _, rule := range newRules {
|
||||
if err := appendRule(n, curIPT, rule.chain, rule.rule); err != nil {
|
||||
return fmt.Errorf("add rule %q to chain %q: %w", rule.rule, rule.chain, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *fakeIPTablesRunner) DelLoopbackRule(addr netip.Addr) error {
|
||||
curIPT := n.ipt4
|
||||
if addr.Is6() {
|
||||
curIPT = n.ipt6
|
||||
}
|
||||
|
||||
delRule := fmt.Sprintf("-i lo -s %s -j ACCEPT", addr.String())
|
||||
|
||||
return deleteRule(n, curIPT, "filter/ts-input", delRule)
|
||||
}
|
||||
|
||||
func (n *fakeIPTablesRunner) AddHooks() error {
|
||||
newRules := []struct{ chain, rule string }{
|
||||
{"filter/INPUT", "-j ts-input"},
|
||||
{"filter/FORWARD", "-j ts-forward"},
|
||||
{"nat/POSTROUTING", "-j ts-postrouting"},
|
||||
}
|
||||
for _, ipt := range []map[string][]string{n.ipt4, n.ipt6} {
|
||||
for _, r := range newRules {
|
||||
if err := insertRule(n, ipt, r.chain, r.rule); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *fakeIPTablesRunner) DelHooks(logf logger.Logf) error {
|
||||
delRules := []struct{ chain, rule string }{
|
||||
{"filter/INPUT", "-j ts-input"},
|
||||
{"filter/FORWARD", "-j ts-forward"},
|
||||
{"nat/POSTROUTING", "-j ts-postrouting"},
|
||||
}
|
||||
for _, ipt := range []map[string][]string{n.ipt4, n.ipt6} {
|
||||
for _, r := range delRules {
|
||||
if err := deleteRule(n, ipt, r.chain, r.rule); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *fakeIPTablesRunner) AddChains() error {
|
||||
for _, ipt := range []map[string][]string{n.ipt4, n.ipt6} {
|
||||
for _, chain := range []string{"filter/ts-input", "filter/ts-forward", "nat/ts-postrouting"} {
|
||||
ipt[chain] = nil
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *fakeIPTablesRunner) DelChains() error {
|
||||
for _, ipt := range []map[string][]string{n.ipt4, n.ipt6} {
|
||||
for chain := range ipt {
|
||||
if strings.HasPrefix(chain, "filter/ts-") || strings.HasPrefix(chain, "nat/ts-") {
|
||||
delete(ipt, chain)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *fakeIPTablesRunner) DelBase() error {
|
||||
for _, ipt := range []map[string][]string{n.ipt4, n.ipt6} {
|
||||
for _, chain := range []string{"filter/ts-input", "filter/ts-forward", "nat/ts-postrouting"} {
|
||||
ipt[chain] = nil
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *fakeIPTablesRunner) AddSNATRule() error {
|
||||
newRule := fmt.Sprintf("-m mark --mark %s/%s -j MASQUERADE", linuxfw.TailscaleSubnetRouteMark, linuxfw.TailscaleFwmarkMask)
|
||||
for _, ipt := range []map[string][]string{n.ipt4, n.ipt6} {
|
||||
if err := appendRule(n, ipt, "nat/ts-postrouting", newRule); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *fakeIPTablesRunner) DelSNATRule() error {
|
||||
delRule := fmt.Sprintf("-m mark --mark %s/%s -j MASQUERADE", linuxfw.TailscaleSubnetRouteMark, linuxfw.TailscaleFwmarkMask)
|
||||
for _, ipt := range []map[string][]string{n.ipt4, n.ipt6} {
|
||||
if err := deleteRule(n, ipt, "nat/ts-postrouting", delRule); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *fakeIPTablesRunner) HasIPV6() bool { return true }
|
||||
func (n *fakeIPTablesRunner) HasIPV6NAT() bool { return true }
|
||||
|
||||
// fakeOS implements commandRunner and provides v4 and v6
|
||||
// netfilterRunners, but captures changes without touching the OS.
|
||||
type fakeOS struct {
|
||||
t *testing.T
|
||||
up bool
|
||||
ips []string
|
||||
routes []string
|
||||
rules []string
|
||||
netfilter4 *fakeNetfilter
|
||||
netfilter6 *fakeNetfilter
|
||||
t *testing.T
|
||||
up bool
|
||||
ips []string
|
||||
routes []string
|
||||
rules []string
|
||||
//This test tests on the router level, so we will not bother
|
||||
//with using iptables or nftables, chose the simpler one.
|
||||
nfr netfilterRunner
|
||||
}
|
||||
|
||||
func NewFakeOS(t *testing.T) *fakeOS {
|
||||
return &fakeOS{
|
||||
t: t,
|
||||
netfilter4: newNetfilter(t),
|
||||
netfilter6: newNetfilter(t),
|
||||
t: t,
|
||||
nfr: newIPTablesRunner(t),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -516,23 +635,23 @@ func (o *fakeOS) String() string {
|
||||
}
|
||||
|
||||
var chains []string
|
||||
for chain := range o.netfilter4.n {
|
||||
for chain := range o.nfr.(*fakeIPTablesRunner).ipt4 {
|
||||
chains = append(chains, chain)
|
||||
}
|
||||
sort.Strings(chains)
|
||||
for _, chain := range chains {
|
||||
for _, rule := range o.netfilter4.n[chain] {
|
||||
for _, rule := range o.nfr.(*fakeIPTablesRunner).ipt4[chain] {
|
||||
fmt.Fprintf(&b, "v4/%s %s\n", chain, rule)
|
||||
}
|
||||
}
|
||||
|
||||
chains = nil
|
||||
for chain := range o.netfilter6.n {
|
||||
for chain := range o.nfr.(*fakeIPTablesRunner).ipt6 {
|
||||
chains = append(chains, chain)
|
||||
}
|
||||
sort.Strings(chains)
|
||||
for _, chain := range chains {
|
||||
for _, rule := range o.netfilter6.n[chain] {
|
||||
for _, rule := range o.nfr.(*fakeIPTablesRunner).ipt6[chain] {
|
||||
fmt.Fprintf(&b, "v6/%s %s\n", chain, rule)
|
||||
}
|
||||
}
|
||||
@@ -806,7 +925,7 @@ func TestDebugListRules(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCheckIPRuleSupportsV6(t *testing.T) {
|
||||
err := checkIPRuleSupportsV6(t.Logf)
|
||||
err := linuxfw.CheckIPRuleSupportsV6(t.Logf)
|
||||
if err != nil && os.Getuid() != 0 {
|
||||
t.Skipf("skipping, error when not root: %v", err)
|
||||
}
|
||||
|
Reference in New Issue
Block a user