diff --git a/util/linuxfw/fake_netfilter.go b/util/linuxfw/fake_netfilter.go new file mode 100644 index 000000000..329c3a213 --- /dev/null +++ b/util/linuxfw/fake_netfilter.go @@ -0,0 +1,95 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package linuxfw + +import ( + "net/netip" + + "tailscale.com/types/logger" +) + +// FakeNetfilterRunner is a fake netfilter runner for tests. +type FakeNetfilterRunner struct { + // services is a map that tracks the firewall rules added/deleted via + // EnsureDNATRuleForSvc/DeleteDNATRuleForSvc. + services map[string]struct { + VIPServiceIP netip.Addr + ClusterIP netip.Addr + } +} + +// NewFakeNetfilterRunner creates a new FakeNetfilterRunner. +func NewFakeNetfilterRunner() *FakeNetfilterRunner { + return &FakeNetfilterRunner{ + services: make(map[string]struct { + VIPServiceIP netip.Addr + ClusterIP netip.Addr + }), + } +} + +func (f *FakeNetfilterRunner) EnsureDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error { + f.services[svcName] = struct { + VIPServiceIP netip.Addr + ClusterIP netip.Addr + }{origDst, dst} + return nil +} + +func (f *FakeNetfilterRunner) DeleteDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error { + delete(f.services, svcName) + return nil +} + +func (f *FakeNetfilterRunner) GetServiceState() map[string]struct { + VIPServiceIP netip.Addr + ClusterIP netip.Addr +} { + return f.services +} + +func (f *FakeNetfilterRunner) HasIPV6() bool { + return true +} + +func (f *FakeNetfilterRunner) HasIPV6Filter() bool { + return true +} + +func (f *FakeNetfilterRunner) HasIPV6NAT() bool { + return true +} + +func (f *FakeNetfilterRunner) AddBase(tunname string) error { return nil } +func (f *FakeNetfilterRunner) DelBase() error { return nil } +func (f *FakeNetfilterRunner) AddChains() error { return nil } +func (f *FakeNetfilterRunner) DelChains() error { return nil } +func (f *FakeNetfilterRunner) AddHooks() error { return nil } +func (f *FakeNetfilterRunner) DelHooks(logf logger.Logf) error { return nil } +func (f *FakeNetfilterRunner) AddSNATRule() error { return nil } +func (f *FakeNetfilterRunner) DelSNATRule() error { return nil } +func (f *FakeNetfilterRunner) AddStatefulRule(tunname string) error { return nil } +func (f *FakeNetfilterRunner) DelStatefulRule(tunname string) error { return nil } +func (f *FakeNetfilterRunner) AddLoopbackRule(addr netip.Addr) error { return nil } +func (f *FakeNetfilterRunner) DelLoopbackRule(addr netip.Addr) error { return nil } +func (f *FakeNetfilterRunner) AddDNATRule(origDst, dst netip.Addr) error { return nil } +func (f *FakeNetfilterRunner) DNATWithLoadBalancer(origDst netip.Addr, dsts []netip.Addr) error { + return nil +} +func (f *FakeNetfilterRunner) EnsureSNATForDst(src, dst netip.Addr) error { return nil } +func (f *FakeNetfilterRunner) DNATNonTailscaleTraffic(tun string, dst netip.Addr) error { return nil } +func (f *FakeNetfilterRunner) ClampMSSToPMTU(tun string, addr netip.Addr) error { return nil } +func (f *FakeNetfilterRunner) AddMagicsockPortRule(port uint16, network string) error { return nil } +func (f *FakeNetfilterRunner) DelMagicsockPortRule(port uint16, network string) error { return nil } +func (f *FakeNetfilterRunner) DeletePortMapRuleForSvc(svc, tun string, targetIP netip.Addr, pm PortMap) error { + return nil +} +func (f *FakeNetfilterRunner) DeleteSvc(svc, tun string, targetIPs []netip.Addr, pms []PortMap) error { + return nil +} +func (f *FakeNetfilterRunner) EnsurePortMapRuleForSvc(svc, tun string, targetIP netip.Addr, pm PortMap) error { + return nil +} diff --git a/util/linuxfw/iptables_for_svcs.go b/util/linuxfw/iptables_for_svcs.go index 8e0f5d48d..2cd8716e4 100644 --- a/util/linuxfw/iptables_for_svcs.go +++ b/util/linuxfw/iptables_for_svcs.go @@ -13,6 +13,7 @@ import ( // This file contains functionality to insert portmapping rules for a 'service'. // These are currently only used by the Kubernetes operator proxies. // An iptables rule for such a service contains a comment with the service name. +// A 'service' corresponds to a VIPService as used by the Kubernetes operator. // EnsurePortMapRuleForSvc adds a prerouting rule that forwards traffic received // on match port and NOT on the provided interface to target IP and target port. @@ -24,10 +25,10 @@ func (i *iptablesRunner) EnsurePortMapRuleForSvc(svc, tun string, targetIP netip if err != nil { return fmt.Errorf("error checking if rule exists: %w", err) } - if !exists { - return table.Append("nat", "PREROUTING", args...) + if exists { + return nil } - return nil + return table.Append("nat", "PREROUTING", args...) } // DeleteMapRuleForSvc constructs a prerouting rule as would be created by @@ -40,10 +41,41 @@ func (i *iptablesRunner) DeletePortMapRuleForSvc(svc, excludeI string, targetIP if err != nil { return fmt.Errorf("error checking if rule exists: %w", err) } - if exists { - return table.Delete("nat", "PREROUTING", args...) + if !exists { + return nil } - return nil + return table.Delete("nat", "PREROUTING", args...) +} + +// EnsureDNATRuleForSvc adds a DNAT rule that forwards traffic from the +// VIPService IP address to a local address. This is used by the Kubernetes +// operator's network layer proxies to forward tailnet traffic for VIPServices +// to Kubernetes Services. +func (i *iptablesRunner) EnsureDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error { + table := i.getIPTByAddr(dst) + args := argsForIngressRule(svcName, origDst, dst) + exists, err := table.Exists("nat", "PREROUTING", args...) + if err != nil { + return fmt.Errorf("error checking if rule exists: %w", err) + } + if exists { + return nil + } + return table.Append("nat", "PREROUTING", args...) +} + +// DeleteDNATRuleForSvc deletes a DNAT rule created by EnsureDNATRuleForSvc. +func (i *iptablesRunner) DeleteDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error { + table := i.getIPTByAddr(dst) + args := argsForIngressRule(svcName, origDst, dst) + exists, err := table.Exists("nat", "PREROUTING", args...) + if err != nil { + return fmt.Errorf("error checking if rule exists: %w", err) + } + if !exists { + return nil + } + return table.Delete("nat", "PREROUTING", args...) } // DeleteSvc constructs all possible rules that would have been created by @@ -72,8 +104,24 @@ func argsForPortMapRule(svc, excludeI string, targetIP netip.Addr, pm PortMap) [ } } +func argsForIngressRule(svcName string, origDst, targetIP netip.Addr) []string { + c := commentForIngressSvc(svcName, origDst, targetIP) + return []string{ + "--destination", origDst.String(), + "-m", "comment", "--comment", c, + "-j", "DNAT", + "--to-destination", targetIP.String(), + } +} + // commentForSvc generates a comment to be added to an iptables DNAT rule for a // service. This is for iptables debugging/readability purposes only. func commentForSvc(svc string, pm PortMap) string { return fmt.Sprintf("%s:%s:%d -> %s:%d", svc, pm.Protocol, pm.MatchPort, pm.Protocol, pm.TargetPort) } + +// commentForIngressSvc generates a comment to be added to an iptables DNAT rule for a +// service. This is for iptables debugging/readability purposes only. +func commentForIngressSvc(svc string, vip, clusterIP netip.Addr) string { + return fmt.Sprintf("svc: %s, %s -> %s", svc, vip.String(), clusterIP.String()) +} diff --git a/util/linuxfw/iptables_for_svcs_test.go b/util/linuxfw/iptables_for_svcs_test.go index 99b2f517f..c3c1b1f65 100644 --- a/util/linuxfw/iptables_for_svcs_test.go +++ b/util/linuxfw/iptables_for_svcs_test.go @@ -153,6 +153,135 @@ func Test_iptablesRunner_DeleteSvc(t *testing.T) { svcMustExist(t, "svc2", map[string][]string{v4Addr.String(): s2R1, v6Addr.String(): s2R2}, iptr) } +func Test_iptablesRunner_EnsureDNATRuleForSvc(t *testing.T) { + v4OrigDst := netip.MustParseAddr("10.0.0.1") + v4Target := netip.MustParseAddr("10.0.0.2") + v6OrigDst := netip.MustParseAddr("fd7a:115c:a1e0::1") + v6Target := netip.MustParseAddr("fd7a:115c:a1e0::2") + v4Rule := argsForIngressRule("svc:test", v4OrigDst, v4Target) + + tests := []struct { + name string + svcName string + origDst netip.Addr + targetIP netip.Addr + precreateSvcRules [][]string + }{ + { + name: "dnat_for_ipv4", + svcName: "svc:test", + origDst: v4OrigDst, + targetIP: v4Target, + }, + { + name: "dnat_for_ipv6", + svcName: "svc:test-2", + origDst: v6OrigDst, + targetIP: v6Target, + }, + { + name: "add_existing_rule", + svcName: "svc:test", + origDst: v4OrigDst, + targetIP: v4Target, + precreateSvcRules: [][]string{v4Rule}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + iptr := NewFakeIPTablesRunner() + table := iptr.getIPTByAddr(tt.targetIP) + for _, ruleset := range tt.precreateSvcRules { + mustPrecreateDNATRule(t, ruleset, table) + } + if err := iptr.EnsureDNATRuleForSvc(tt.svcName, tt.origDst, tt.targetIP); err != nil { + t.Errorf("[unexpected error] iptablesRunner.EnsureDNATRuleForSvc() = %v", err) + } + args := argsForIngressRule(tt.svcName, tt.origDst, tt.targetIP) + exists, err := table.Exists("nat", "PREROUTING", args...) + if err != nil { + t.Fatalf("error checking if rule exists: %v", err) + } + if !exists { + t.Errorf("expected rule was not created") + } + }) + } +} + +func Test_iptablesRunner_DeleteDNATRuleForSvc(t *testing.T) { + v4OrigDst := netip.MustParseAddr("10.0.0.1") + v4Target := netip.MustParseAddr("10.0.0.2") + v6OrigDst := netip.MustParseAddr("fd7a:115c:a1e0::1") + v6Target := netip.MustParseAddr("fd7a:115c:a1e0::2") + v4Rule := argsForIngressRule("svc:test", v4OrigDst, v4Target) + v6Rule := argsForIngressRule("svc:test", v6OrigDst, v6Target) + + tests := []struct { + name string + svcName string + origDst netip.Addr + targetIP netip.Addr + precreateSvcRules [][]string + }{ + { + name: "multiple_rules_ipv4_deleted", + svcName: "svc:test", + origDst: v4OrigDst, + targetIP: v4Target, + precreateSvcRules: [][]string{v4Rule, v6Rule}, + }, + { + name: "multiple_rules_ipv6_deleted", + svcName: "svc:test", + origDst: v6OrigDst, + targetIP: v6Target, + precreateSvcRules: [][]string{v4Rule, v6Rule}, + }, + { + name: "non-existent_rule_deleted", + svcName: "svc:test", + origDst: v4OrigDst, + targetIP: v4Target, + precreateSvcRules: [][]string{v6Rule}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + iptr := NewFakeIPTablesRunner() + table := iptr.getIPTByAddr(tt.targetIP) + for _, ruleset := range tt.precreateSvcRules { + mustPrecreateDNATRule(t, ruleset, table) + } + if err := iptr.DeleteDNATRuleForSvc(tt.svcName, tt.origDst, tt.targetIP); err != nil { + t.Errorf("iptablesRunner.DeleteDNATRuleForSvc() errored: %v ", err) + } + deletedRule := argsForIngressRule(tt.svcName, tt.origDst, tt.targetIP) + exists, err := table.Exists("nat", "PREROUTING", deletedRule...) + if err != nil { + t.Fatalf("error verifying that rule does not exist after deletion: %v", err) + } + if exists { + t.Errorf("DNAT rule exists after deletion") + } + }) + } +} + +func mustPrecreateDNATRule(t *testing.T, rules []string, table iptablesInterface) { + t.Helper() + exists, err := table.Exists("nat", "PREROUTING", rules...) + if err != nil { + t.Fatalf("error ensuring that nat PREROUTING table exists: %v", err) + } + if exists { + return + } + if err := table.Append("nat", "PREROUTING", rules...); err != nil { + t.Fatalf("error precreating DNAT rule: %v", err) + } +} + func svcMustExist(t *testing.T, svcName string, rules map[string][]string, iptr *iptablesRunner) { t.Helper() for dst, ruleset := range rules { diff --git a/util/linuxfw/nftables_for_svcs.go b/util/linuxfw/nftables_for_svcs.go index 130585b22..474b98086 100644 --- a/util/linuxfw/nftables_for_svcs.go +++ b/util/linuxfw/nftables_for_svcs.go @@ -119,6 +119,63 @@ func (n *nftablesRunner) DeleteSvc(svc, tun string, targetIPs []netip.Addr, pm [ return n.conn.Flush() } +// EnsureDNATRuleForSvc adds a DNAT rule that forwards traffic from the +// VIPService IP address to a local address. This is used by the Kubernetes +// operator's network layer proxies to forward tailnet traffic for VIPServices +// to Kubernetes Services. +func (n *nftablesRunner) EnsureDNATRuleForSvc(svc string, origDst, dst netip.Addr) error { + t, ch, err := n.ensurePreroutingChain(origDst) + if err != nil { + return fmt.Errorf("error ensuring chain for %s: %w", svc, err) + } + meta := svcRuleMeta(svc, origDst, dst) + rule, err := n.findRuleByMetadata(t, ch, meta) + if err != nil { + return fmt.Errorf("error looking up rule: %w", err) + } + if rule != nil { + return nil + } + rule = dnatRuleForChain(t, ch, origDst, dst, meta) + n.conn.InsertRule(rule) + return n.conn.Flush() +} + +// DeleteDNATRuleForSvc deletes a DNAT rule created by EnsureDNATRuleForSvc. +// We use the metadata attached to the rule to look it up. +func (n *nftablesRunner) DeleteDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error { + table, err := n.getNFTByAddr(origDst) + if err != nil { + return fmt.Errorf("error setting up nftables for IP family of %s: %w", origDst, err) + } + t, err := getTableIfExists(n.conn, table.Proto, "nat") + if err != nil { + return fmt.Errorf("error checking if nat table exists: %w", err) + } + if t == nil { + return nil + } + ch, err := getChainFromTable(n.conn, t, "PREROUTING") + if errors.Is(err, errorChainNotFound{tableName: "nat", chainName: "PREROUTING"}) { + return nil + } + if err != nil { + return fmt.Errorf("error checking if chain PREROUTING exists: %w", err) + } + meta := svcRuleMeta(svcName, origDst, dst) + rule, err := n.findRuleByMetadata(t, ch, meta) + if err != nil { + return fmt.Errorf("error checking if rule exists: %w", err) + } + if rule == nil { + return nil + } + if err := n.conn.DelRule(rule); err != nil { + return fmt.Errorf("error deleting rule: %w", err) + } + return n.conn.Flush() +} + func portMapRule(t *nftables.Table, ch *nftables.Chain, tun string, targetIP netip.Addr, matchPort, targetPort uint16, proto uint8, meta []byte) *nftables.Rule { var fam uint32 if targetIP.Is4() { @@ -243,3 +300,10 @@ func protoFromString(s string) (uint8, error) { return 0, fmt.Errorf("unrecognized protocol: %q", s) } } + +// svcRuleMeta generates metadata for a rule. +// This metadata can then be used to find the rule. +// https://github.com/google/nftables/issues/48 +func svcRuleMeta(svcName string, origDst, dst netip.Addr) []byte { + return []byte(fmt.Sprintf("svc:%s,VIP:%s,ClusterIP:%s", svcName, origDst.String(), dst.String())) +} diff --git a/util/linuxfw/nftables_for_svcs_test.go b/util/linuxfw/nftables_for_svcs_test.go index d2df6e4bd..73472ce20 100644 --- a/util/linuxfw/nftables_for_svcs_test.go +++ b/util/linuxfw/nftables_for_svcs_test.go @@ -14,8 +14,9 @@ import ( // This test creates a temporary network namespace for the nftables rules being // set up, so it needs to run in a privileged mode. Locally it needs to be run -// by root, else it will be silently skipped. In CI it runs in a privileged -// container. +// by root, else it will be silently skipped. +// sudo go test -v -run Test_nftablesRunner_EnsurePortMapRuleForSvc ./util/linuxfw/... +// In CI it runs in a privileged container. func Test_nftablesRunner_EnsurePortMapRuleForSvc(t *testing.T) { conn := newSysConn(t) runner := newFakeNftablesRunnerWithConn(t, conn, true) @@ -23,51 +24,215 @@ func Test_nftablesRunner_EnsurePortMapRuleForSvc(t *testing.T) { pmTCP := PortMap{MatchPort: 4003, TargetPort: 80, Protocol: "TCP"} pmTCP1 := PortMap{MatchPort: 4004, TargetPort: 443, Protocol: "TCP"} - // Create a rule for service 'foo' to forward TCP traffic to IPv4 endpoint - runner.EnsurePortMapRuleForSvc("foo", "tailscale0", ipv4, pmTCP) + // Create a rule for service 'svc:foo' to forward TCP traffic to IPv4 endpoint + runner.EnsurePortMapRuleForSvc("svc:foo", "tailscale0", ipv4, pmTCP) svcChains(t, 1, conn) - chainRuleCount(t, "foo", 1, conn, nftables.TableFamilyIPv4) - checkPortMapRule(t, "foo", ipv4, pmTCP, runner, nftables.TableFamilyIPv4) + chainRuleCount(t, "svc:foo", 1, conn, nftables.TableFamilyIPv4) + checkPortMapRule(t, "svc:foo", ipv4, pmTCP, runner, nftables.TableFamilyIPv4) - // Create another rule for service 'foo' to forward TCP traffic to the + // Create another rule for service 'svc:foo' to forward TCP traffic to the // same IPv4 endpoint, but to a different port. - runner.EnsurePortMapRuleForSvc("foo", "tailscale0", ipv4, pmTCP1) + runner.EnsurePortMapRuleForSvc("svc:foo", "tailscale0", ipv4, pmTCP1) svcChains(t, 1, conn) - chainRuleCount(t, "foo", 2, conn, nftables.TableFamilyIPv4) - checkPortMapRule(t, "foo", ipv4, pmTCP1, runner, nftables.TableFamilyIPv4) + chainRuleCount(t, "svc:foo", 2, conn, nftables.TableFamilyIPv4) + checkPortMapRule(t, "svc:foo", ipv4, pmTCP1, runner, nftables.TableFamilyIPv4) - // Create a rule for service 'foo' to forward TCP traffic to an IPv6 endpoint - runner.EnsurePortMapRuleForSvc("foo", "tailscale0", ipv6, pmTCP) + // Create a rule for service 'svc:foo' to forward TCP traffic to an IPv6 endpoint + runner.EnsurePortMapRuleForSvc("svc:foo", "tailscale0", ipv6, pmTCP) svcChains(t, 2, conn) - chainRuleCount(t, "foo", 1, conn, nftables.TableFamilyIPv6) - checkPortMapRule(t, "foo", ipv6, pmTCP, runner, nftables.TableFamilyIPv6) + chainRuleCount(t, "svc:foo", 1, conn, nftables.TableFamilyIPv6) + checkPortMapRule(t, "svc:foo", ipv6, pmTCP, runner, nftables.TableFamilyIPv6) - // Create a rule for service 'bar' to forward TCP traffic to IPv4 endpoint - runner.EnsurePortMapRuleForSvc("bar", "tailscale0", ipv4, pmTCP) + // Create a rule for service 'svc:bar' to forward TCP traffic to IPv4 endpoint + runner.EnsurePortMapRuleForSvc("svc:bar", "tailscale0", ipv4, pmTCP) svcChains(t, 3, conn) - chainRuleCount(t, "bar", 1, conn, nftables.TableFamilyIPv4) - checkPortMapRule(t, "bar", ipv4, pmTCP, runner, nftables.TableFamilyIPv4) + chainRuleCount(t, "svc:bar", 1, conn, nftables.TableFamilyIPv4) + checkPortMapRule(t, "svc:bar", ipv4, pmTCP, runner, nftables.TableFamilyIPv4) - // Create a rule for service 'bar' to forward TCP traffic to an IPv6 endpoint - runner.EnsurePortMapRuleForSvc("bar", "tailscale0", ipv6, pmTCP) + // Create a rule for service 'svc:bar' to forward TCP traffic to an IPv6 endpoint + runner.EnsurePortMapRuleForSvc("svc:bar", "tailscale0", ipv6, pmTCP) svcChains(t, 4, conn) - chainRuleCount(t, "bar", 1, conn, nftables.TableFamilyIPv6) - checkPortMapRule(t, "bar", ipv6, pmTCP, runner, nftables.TableFamilyIPv6) + chainRuleCount(t, "svc:bar", 1, conn, nftables.TableFamilyIPv6) + checkPortMapRule(t, "svc:bar", ipv6, pmTCP, runner, nftables.TableFamilyIPv6) - // Delete service bar - runner.DeleteSvc("bar", "tailscale0", []netip.Addr{ipv4, ipv6}, []PortMap{pmTCP}) + // Delete service svc:bar + runner.DeleteSvc("svc:bar", "tailscale0", []netip.Addr{ipv4, ipv6}, []PortMap{pmTCP}) svcChains(t, 2, conn) - // Delete a rule from service foo - runner.DeletePortMapRuleForSvc("foo", "tailscale0", ipv4, pmTCP) + // Delete a rule from service svc:foo + runner.DeletePortMapRuleForSvc("svc:foo", "tailscale0", ipv4, pmTCP) svcChains(t, 2, conn) - chainRuleCount(t, "foo", 1, conn, nftables.TableFamilyIPv4) + chainRuleCount(t, "svc:foo", 1, conn, nftables.TableFamilyIPv4) - // Delete service foo - runner.DeleteSvc("foo", "tailscale0", []netip.Addr{ipv4, ipv6}, []PortMap{pmTCP, pmTCP1}) + // Delete service svc:foo + runner.DeleteSvc("svc:foo", "tailscale0", []netip.Addr{ipv4, ipv6}, []PortMap{pmTCP, pmTCP1}) svcChains(t, 0, conn) } +func Test_nftablesRunner_EnsureDNATRuleForSvc(t *testing.T) { + conn := newSysConn(t) + runner := newFakeNftablesRunnerWithConn(t, conn, true) + + // Test IPv4 DNAT rule + ipv4OrigDst := netip.MustParseAddr("10.0.0.1") + ipv4Target := netip.MustParseAddr("10.0.0.2") + + // Create DNAT rule for service 'svc:foo' to forward IPv4 traffic + err := runner.EnsureDNATRuleForSvc("svc:foo", ipv4OrigDst, ipv4Target) + if err != nil { + t.Fatalf("error creating IPv4 DNAT rule: %v", err) + } + checkDNATRule(t, "svc:foo", ipv4OrigDst, ipv4Target, runner, nftables.TableFamilyIPv4) + + // Test IPv6 DNAT rule + ipv6OrigDst := netip.MustParseAddr("fd7a:115c:a1e0::1") + ipv6Target := netip.MustParseAddr("fd7a:115c:a1e0::2") + + // Create DNAT rule for service 'svc:foo' to forward IPv6 traffic + err = runner.EnsureDNATRuleForSvc("svc:foo", ipv6OrigDst, ipv6Target) + if err != nil { + t.Fatalf("error creating IPv6 DNAT rule: %v", err) + } + checkDNATRule(t, "svc:foo", ipv6OrigDst, ipv6Target, runner, nftables.TableFamilyIPv6) + + // Test creating rule for another service + err = runner.EnsureDNATRuleForSvc("svc:bar", ipv4OrigDst, ipv4Target) + if err != nil { + t.Fatalf("error creating DNAT rule for service 'svc:bar': %v", err) + } + checkDNATRule(t, "svc:bar", ipv4OrigDst, ipv4Target, runner, nftables.TableFamilyIPv4) +} + +func Test_nftablesRunner_DeleteDNATRuleForSvc(t *testing.T) { + conn := newSysConn(t) + runner := newFakeNftablesRunnerWithConn(t, conn, true) + + // Test IPv4 DNAT rule deletion + ipv4OrigDst := netip.MustParseAddr("10.0.0.1") + ipv4Target := netip.MustParseAddr("10.0.0.2") + + // Create and then delete IPv4 DNAT rule + err := runner.EnsureDNATRuleForSvc("svc:foo", ipv4OrigDst, ipv4Target) + if err != nil { + t.Fatalf("error creating IPv4 DNAT rule: %v", err) + } + + // Verify rule exists before deletion + table, err := runner.getNFTByAddr(ipv4OrigDst) + if err != nil { + t.Fatalf("error getting table: %v", err) + } + nftTable, err := getTableIfExists(runner.conn, table.Proto, "nat") + if err != nil { + t.Fatalf("error getting nat table: %v", err) + } + ch, err := getChainFromTable(runner.conn, nftTable, "PREROUTING") + if err != nil { + t.Fatalf("error getting PREROUTING chain: %v", err) + } + meta := svcRuleMeta("svc:foo", ipv4OrigDst, ipv4Target) + rule, err := runner.findRuleByMetadata(nftTable, ch, meta) + if err != nil { + t.Fatalf("error checking if rule exists: %v", err) + } + if rule == nil { + t.Fatal("rule does not exist before deletion") + } + + err = runner.DeleteDNATRuleForSvc("svc:foo", ipv4OrigDst, ipv4Target) + if err != nil { + t.Fatalf("error deleting IPv4 DNAT rule: %v", err) + } + + // Verify rule is deleted + rule, err = runner.findRuleByMetadata(nftTable, ch, meta) + if err != nil { + t.Fatalf("error checking if rule exists: %v", err) + } + if rule != nil { + t.Fatal("rule still exists after deletion") + } + + // Test IPv6 DNAT rule deletion + ipv6OrigDst := netip.MustParseAddr("fd7a:115c:a1e0::1") + ipv6Target := netip.MustParseAddr("fd7a:115c:a1e0::2") + + // Create and then delete IPv6 DNAT rule + err = runner.EnsureDNATRuleForSvc("svc:foo", ipv6OrigDst, ipv6Target) + if err != nil { + t.Fatalf("error creating IPv6 DNAT rule: %v", err) + } + + // Verify rule exists before deletion + table, err = runner.getNFTByAddr(ipv6OrigDst) + if err != nil { + t.Fatalf("error getting table: %v", err) + } + nftTable, err = getTableIfExists(runner.conn, table.Proto, "nat") + if err != nil { + t.Fatalf("error getting nat table: %v", err) + } + ch, err = getChainFromTable(runner.conn, nftTable, "PREROUTING") + if err != nil { + t.Fatalf("error getting PREROUTING chain: %v", err) + } + meta = svcRuleMeta("svc:foo", ipv6OrigDst, ipv6Target) + rule, err = runner.findRuleByMetadata(nftTable, ch, meta) + if err != nil { + t.Fatalf("error checking if rule exists: %v", err) + } + if rule == nil { + t.Fatal("rule does not exist before deletion") + } + + err = runner.DeleteDNATRuleForSvc("svc:foo", ipv6OrigDst, ipv6Target) + if err != nil { + t.Fatalf("error deleting IPv6 DNAT rule: %v", err) + } + + // Verify rule is deleted + rule, err = runner.findRuleByMetadata(nftTable, ch, meta) + if err != nil { + t.Fatalf("error checking if rule exists: %v", err) + } + if rule != nil { + t.Fatal("rule still exists after deletion") + } +} + +// checkDNATRule verifies that a DNAT rule exists for the given service, original destination, and target IP. +func checkDNATRule(t *testing.T, svc string, origDst, targetIP netip.Addr, runner *nftablesRunner, fam nftables.TableFamily) { + t.Helper() + table, err := runner.getNFTByAddr(origDst) + if err != nil { + t.Fatalf("error getting table: %v", err) + } + nftTable, err := getTableIfExists(runner.conn, table.Proto, "nat") + if err != nil { + t.Fatalf("error getting nat table: %v", err) + } + if nftTable == nil { + t.Fatal("nat table not found") + } + + ch, err := getChainFromTable(runner.conn, nftTable, "PREROUTING") + if err != nil { + t.Fatalf("error getting PREROUTING chain: %v", err) + } + if ch == nil { + t.Fatal("PREROUTING chain not found") + } + + meta := svcRuleMeta(svc, origDst, targetIP) + rule, err := runner.findRuleByMetadata(nftTable, ch, meta) + if err != nil { + t.Fatalf("error checking if rule exists: %v", err) + } + if rule == nil { + t.Fatal("DNAT rule not found") + } +} + // svcChains verifies that the expected number of chains exist (for either IP // family) and that each of them is configured as NAT prerouting chain. func svcChains(t *testing.T, wantCount int, conn *nftables.Conn) { diff --git a/util/linuxfw/nftables_runner.go b/util/linuxfw/nftables_runner.go index b87298c61..faa02f7c7 100644 --- a/util/linuxfw/nftables_runner.go +++ b/util/linuxfw/nftables_runner.go @@ -107,6 +107,12 @@ func (n *nftablesRunner) AddDNATRule(origDst netip.Addr, dst netip.Addr) error { if err != nil { return err } + rule := dnatRuleForChain(nat, preroutingCh, origDst, dst, nil) + n.conn.InsertRule(rule) + return n.conn.Flush() +} + +func dnatRuleForChain(t *nftables.Table, ch *nftables.Chain, origDst, dst netip.Addr, meta []byte) *nftables.Rule { var daddrOffset, fam, dadderLen uint32 if origDst.Is4() { daddrOffset = 16 @@ -117,9 +123,9 @@ func (n *nftablesRunner) AddDNATRule(origDst netip.Addr, dst netip.Addr) error { dadderLen = 16 fam = unix.NFPROTO_IPV6 } - dnatRule := &nftables.Rule{ - Table: nat, - Chain: preroutingCh, + rule := &nftables.Rule{ + Table: t, + Chain: ch, Exprs: []expr.Any{ &expr.Payload{ DestRegister: 1, @@ -143,8 +149,10 @@ func (n *nftablesRunner) AddDNATRule(origDst netip.Addr, dst netip.Addr) error { }, }, } - n.conn.InsertRule(dnatRule) - return n.conn.Flush() + if len(meta) > 0 { + rule.UserData = meta + } + return rule } // DNATWithLoadBalancer currently just forwards all traffic destined for origDst @@ -555,6 +563,8 @@ type NetfilterRunner interface { EnsurePortMapRuleForSvc(svc, tun string, targetIP netip.Addr, pm PortMap) error DeletePortMapRuleForSvc(svc, tun string, targetIP netip.Addr, pm PortMap) error + EnsureDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error + DeleteDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error DeleteSvc(svc, tun string, targetIPs []netip.Addr, pm []PortMap) error diff --git a/wgengine/router/router_linux_test.go b/wgengine/router/router_linux_test.go index 7ddd7385d..a289fb0ac 100644 --- a/wgengine/router/router_linux_test.go +++ b/wgengine/router/router_linux_test.go @@ -557,6 +557,14 @@ func (n *fakeIPTablesRunner) ClampMSSToPMTU(tun string, addr netip.Addr) error { return errors.New("not implemented") } +func (n *fakeIPTablesRunner) EnsureDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error { + return errors.New("not implemented") +} + +func (n *fakeIPTablesRunner) DeleteDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error { + return errors.New("not implemented") +} + func (n *fakeIPTablesRunner) addBase4(tunname string) error { curIPT := n.ipt4 newRules := []struct{ chain, rule string }{