diff --git a/util/linuxfw/iptables.go b/util/linuxfw/iptables.go index 4527fd076..3b57b9c67 100644 --- a/util/linuxfw/iptables.go +++ b/util/linuxfw/iptables.go @@ -7,7 +7,13 @@ package linuxfw import ( + "fmt" + "os/exec" + "strings" + "unicode" + "tailscale.com/types/logger" + "tailscale.com/util/multierr" ) // DebugNetfilter prints debug information about iptables rules to the @@ -21,9 +27,44 @@ func DebugIptables(logf logger.Logf) error { // system, ignoring the default "ACCEPT" rule present in the standard iptables // chains. // -// It only returns an error when the kernel returns an error (i.e. when a -// syscall fails); when there are no iptables rules, it is valid for this -// function to return 0, nil. +// It only returns an error when there is no iptables binary, or when iptables -S +// fails. In all other cases, it returns the number of non-default rules. func DetectIptables() (int, error) { - panic("unused") + // run "iptables -S" to get the list of rules using iptables + // exec.Command returns an error if the binary is not found + cmd := exec.Command("iptables", "-S") + output, err := cmd.Output() + ip6cmd := exec.Command("ip6tables", "-S") + ip6output, ip6err := ip6cmd.Output() + var allLines []string + outputStr := string(output) + lines := strings.Split(outputStr, "\n") + ip6outputStr := string(ip6output) + ip6lines := strings.Split(ip6outputStr, "\n") + switch { + case err == nil && ip6err == nil: + allLines = append(lines, ip6lines...) + case err == nil && ip6err != nil: + allLines = lines + case err != nil && ip6err == nil: + allLines = ip6lines + default: + return 0, ErrorFWModeNotSupported{ + Mode: FirewallModeIPTables, + Err: fmt.Errorf("iptables command run fail: %w", multierr.New(err, ip6err)), + } + } + + // count the number of non-default rules + count := 0 + for _, line := range allLines { + trimmedLine := strings.TrimLeftFunc(line, unicode.IsSpace) + if line != "" && strings.HasPrefix(trimmedLine, "-A") { + // if the line is not empty and starts with "-A", it is a rule appended not default + count++ + } + } + + // return the count of non-default rules + return count, nil } diff --git a/util/linuxfw/linuxfw.go b/util/linuxfw/linuxfw.go index 6ec152a4f..7a3bd02c0 100644 --- a/util/linuxfw/linuxfw.go +++ b/util/linuxfw/linuxfw.go @@ -29,6 +29,31 @@ Masq ) +type ErrorFWModeNotSupported struct { + Mode FirewallMode + Err error +} + +func (e ErrorFWModeNotSupported) Error() string { + return fmt.Sprintf("firewall mode %q not supported: %v", e.Mode, e.Err) +} + +func (e ErrorFWModeNotSupported) Is(target error) bool { + _, ok := target.(ErrorFWModeNotSupported) + return ok +} + +func (e ErrorFWModeNotSupported) Unwrap() error { + return e.Err +} + +type FirewallMode string + +const ( + FirewallModeIPTables FirewallMode = "iptables" + FirewallModeNfTables FirewallMode = "nftables" +) + // The following bits are added to packet marks for Tailscale use. // // We tried to pick bits sufficiently out of the way that it's diff --git a/util/linuxfw/nftables.go b/util/linuxfw/nftables.go index ce0b022fa..6a462a890 100644 --- a/util/linuxfw/nftables.go +++ b/util/linuxfw/nftables.go @@ -107,12 +107,18 @@ func DebugNetfilter(logf logger.Logf) error { func DetectNetfilter() (int, error) { conn, err := nftables.New() if err != nil { - return 0, err + return 0, ErrorFWModeNotSupported{ + Mode: FirewallModeNfTables, + Err: err, + } } chains, err := conn.ListChains() if err != nil { - return 0, fmt.Errorf("cannot list chains: %w", err) + return 0, ErrorFWModeNotSupported{ + Mode: FirewallModeNfTables, + Err: fmt.Errorf("cannot list chains: %w", err), + } } var validRules int diff --git a/wgengine/router/router_linux.go b/wgengine/router/router_linux.go index 6b723c845..074e61a22 100644 --- a/wgengine/router/router_linux.go +++ b/wgengine/router/router_linux.go @@ -54,24 +54,95 @@ type netfilterRunner interface { HasIPV6NAT() bool } +// tableDetector abstracts helpers to detect the firewall mode. +// It is implemented for testing purposes. +type tableDetector interface { + iptDetect() (int, error) + nftDetect() (int, error) +} + +type linuxFWDetector struct{} + +// iptDetect returns the number of iptables rules in the current namespace. +func (l *linuxFWDetector) iptDetect() (int, error) { + return linuxfw.DetectIptables() +} + +// nftDetect returns the number of nftables rules in the current namespace. +func (l *linuxFWDetector) nftDetect() (int, error) { + return linuxfw.DetectNetfilter() +} + +// chooseFireWallMode returns the firewall mode to use based on the +// environment and the system's capabilities. +func chooseFireWallMode(logf logger.Logf, det tableDetector) (linuxfw.FirewallMode, error) { + iptAva, nftAva := true, true + iptRuleCount, err := det.iptDetect() + if err != nil { + logf("router: detect iptables rule: %v", err) + iptAva = false + } + nftRuleCount, err := det.nftDetect() + if err != nil { + logf("router: detect nftables rule: %v", err) + nftAva = false + } + logf("router: nftables rule count: %d, iptables rule count: %d", nftRuleCount, iptRuleCount) + switch { + case envknob.String("TS_DEBUG_FIREWALL_MODE") == "nftables": + // TODO(KevinLiang10): Updates to a flag + logf("router: envknob TS_DEBUG_FIREWALL_MODE=nftables set") + return linuxfw.FirewallModeNfTables, nil + case envknob.String("TS_DEBUG_FIREWALL_MODE") == "iptables": + logf("router: envknob TS_DEBUG_FIREWALL_MODE=iptables set") + return linuxfw.FirewallModeIPTables, nil + case nftRuleCount > 0 && iptRuleCount == 0: + logf("router: nftables is currently in use") + return linuxfw.FirewallModeNfTables, nil + case iptRuleCount > 0 && nftRuleCount == 0: + logf("router: iptables is currently in use") + return linuxfw.FirewallModeIPTables, nil + case nftAva: + // if both iptables and nftables are available but + // neither/both are currently used, use nftables. + logf("router: nftables is available") + return linuxfw.FirewallModeNfTables, nil + case iptAva: + logf("router: iptables is available") + return linuxfw.FirewallModeIPTables, nil + default: + // if neither iptables nor nftables are available, + // this is an error that shouldn't happen. + return "", errors.New("router: neither iptables nor nftables are available") + } +} + // newNetfilterRunner creates a netfilterRunner using either nftables or iptables. // As nftables is still experimental, iptables will be used unless TS_DEBUG_USE_NETLINK_NFTABLES is set. func newNetfilterRunner(logf logger.Logf) (netfilterRunner, error) { + tableDetector := &linuxFWDetector{} + mode, err := chooseFireWallMode(logf, tableDetector) + if err != nil { + return nil, fmt.Errorf("choosing firewall mode: %w", err) + } var nfr netfilterRunner - var err error - if envknob.Bool("TS_DEBUG_USE_NETLINK_NFTABLES") { - logf("router: using nftables") - nfr, err = linuxfw.NewNfTablesRunner(logf) - if err != nil { - return nil, err - } - } else { + switch mode { + case linuxfw.FirewallModeIPTables: logf("router: using iptables") nfr, err = linuxfw.NewIPTablesRunner(logf) if err != nil { return nil, err } + case linuxfw.FirewallModeNfTables: + logf("router: using nftables") + nfr, err = linuxfw.NewNfTablesRunner(logf) + if err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("unknown firewall mode: %v", mode) } + return nfr, nil } diff --git a/wgengine/router/router_linux_test.go b/wgengine/router/router_linux_test.go index e7998f9b9..32d0272f7 100644 --- a/wgengine/router/router_linux_test.go +++ b/wgengine/router/router_linux_test.go @@ -1063,3 +1063,63 @@ func adjustFwmask(t *testing.T, s string) string { return fwmaskAdjustRe.ReplaceAllString(s, "$1") } + +type testFWDetector struct { + iptRuleCount, nftRuleCount int + iptErr, nftErr error +} + +func (t *testFWDetector) iptDetect() (int, error) { + return t.iptRuleCount, t.iptErr +} + +func (t *testFWDetector) nftDetect() (int, error) { + return t.nftRuleCount, t.nftErr +} + +func TestChooseFireWallMode(t *testing.T) { + tests := []struct { + name string + det *testFWDetector + want linuxfw.FirewallMode + }{ + { + name: "using iptables legacy", + det: &testFWDetector{iptRuleCount: 1}, + want: linuxfw.FirewallModeIPTables, + }, + { + name: "using nftables", + det: &testFWDetector{nftRuleCount: 1}, + want: linuxfw.FirewallModeNfTables, + }, + { + name: "using both iptables and nftables", + det: &testFWDetector{iptRuleCount: 2, nftRuleCount: 2}, + want: linuxfw.FirewallModeNfTables, + }, + { + name: "not using any firewall, both available", + det: &testFWDetector{}, + want: linuxfw.FirewallModeNfTables, + }, + { + name: "not using any firewall, iptables available only", + det: &testFWDetector{iptRuleCount: 1, nftErr: errors.New("nft error")}, + want: linuxfw.FirewallModeIPTables, + }, + { + name: "not using any firewall, nftables available only", + det: &testFWDetector{iptErr: errors.New("iptables error"), nftRuleCount: 1}, + want: linuxfw.FirewallModeNfTables, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, _ := chooseFireWallMode(t.Logf, tt.det) + if got != tt.want { + t.Errorf("chooseFireWallMode() = %v, want %v", got, tt.want) + } + }) + } +}