util/linuxfw,wgengine/router: add new netfilter rules for HA ingresses

Add new rules to update DNAT rules for Kubernetes operator's
HA ingress where it's expected that rules will be added/removed
frequently (so we don't want to keep old rules around or rewrite
existing rules unnecessarily):
- allow deleting DNAT rules using metadata lookup
- allow inserting DNAT rules if they don't already
exist (using metadata lookup)

Updates tailscale/tailscale#15895

Co-authored-by: chaosinthecrd <tom@tmlabs.co.uk>
Signed-off-by: Irbe Krumina <irbe@tailscale.com>
This commit is contained in:
Irbe Krumina 2025-05-07 10:24:57 +01:00
parent 5b597489bc
commit 67ecebd9e2
7 changed files with 559 additions and 40 deletions

View File

@ -0,0 +1,95 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build linux
package linuxfw
import (
"net/netip"
"tailscale.com/types/logger"
)
// FakeNetfilterRunner is a fake netfilter runner for tests.
type FakeNetfilterRunner struct {
// services is a map that tracks the firewall rules added/deleted via
// EnsureDNATRuleForSvc/DeleteDNATRuleForSvc.
services map[string]struct {
VIPServiceIP netip.Addr
ClusterIP netip.Addr
}
}
// NewFakeNetfilterRunner creates a new FakeNetfilterRunner.
func NewFakeNetfilterRunner() *FakeNetfilterRunner {
return &FakeNetfilterRunner{
services: make(map[string]struct {
VIPServiceIP netip.Addr
ClusterIP netip.Addr
}),
}
}
func (f *FakeNetfilterRunner) EnsureDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error {
f.services[svcName] = struct {
VIPServiceIP netip.Addr
ClusterIP netip.Addr
}{origDst, dst}
return nil
}
func (f *FakeNetfilterRunner) DeleteDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error {
delete(f.services, svcName)
return nil
}
func (f *FakeNetfilterRunner) GetServiceState() map[string]struct {
VIPServiceIP netip.Addr
ClusterIP netip.Addr
} {
return f.services
}
func (f *FakeNetfilterRunner) HasIPV6() bool {
return true
}
func (f *FakeNetfilterRunner) HasIPV6Filter() bool {
return true
}
func (f *FakeNetfilterRunner) HasIPV6NAT() bool {
return true
}
func (f *FakeNetfilterRunner) AddBase(tunname string) error { return nil }
func (f *FakeNetfilterRunner) DelBase() error { return nil }
func (f *FakeNetfilterRunner) AddChains() error { return nil }
func (f *FakeNetfilterRunner) DelChains() error { return nil }
func (f *FakeNetfilterRunner) AddHooks() error { return nil }
func (f *FakeNetfilterRunner) DelHooks(logf logger.Logf) error { return nil }
func (f *FakeNetfilterRunner) AddSNATRule() error { return nil }
func (f *FakeNetfilterRunner) DelSNATRule() error { return nil }
func (f *FakeNetfilterRunner) AddStatefulRule(tunname string) error { return nil }
func (f *FakeNetfilterRunner) DelStatefulRule(tunname string) error { return nil }
func (f *FakeNetfilterRunner) AddLoopbackRule(addr netip.Addr) error { return nil }
func (f *FakeNetfilterRunner) DelLoopbackRule(addr netip.Addr) error { return nil }
func (f *FakeNetfilterRunner) AddDNATRule(origDst, dst netip.Addr) error { return nil }
func (f *FakeNetfilterRunner) DNATWithLoadBalancer(origDst netip.Addr, dsts []netip.Addr) error {
return nil
}
func (f *FakeNetfilterRunner) EnsureSNATForDst(src, dst netip.Addr) error { return nil }
func (f *FakeNetfilterRunner) DNATNonTailscaleTraffic(tun string, dst netip.Addr) error { return nil }
func (f *FakeNetfilterRunner) ClampMSSToPMTU(tun string, addr netip.Addr) error { return nil }
func (f *FakeNetfilterRunner) AddMagicsockPortRule(port uint16, network string) error { return nil }
func (f *FakeNetfilterRunner) DelMagicsockPortRule(port uint16, network string) error { return nil }
func (f *FakeNetfilterRunner) DeletePortMapRuleForSvc(svc, tun string, targetIP netip.Addr, pm PortMap) error {
return nil
}
func (f *FakeNetfilterRunner) DeleteSvc(svc, tun string, targetIPs []netip.Addr, pms []PortMap) error {
return nil
}
func (f *FakeNetfilterRunner) EnsurePortMapRuleForSvc(svc, tun string, targetIP netip.Addr, pm PortMap) error {
return nil
}

View File

@ -13,6 +13,7 @@ import (
// This file contains functionality to insert portmapping rules for a 'service'.
// These are currently only used by the Kubernetes operator proxies.
// An iptables rule for such a service contains a comment with the service name.
// A 'service' corresponds to a VIPService as used by the Kubernetes operator.
// EnsurePortMapRuleForSvc adds a prerouting rule that forwards traffic received
// on match port and NOT on the provided interface to target IP and target port.
@ -24,10 +25,10 @@ func (i *iptablesRunner) EnsurePortMapRuleForSvc(svc, tun string, targetIP netip
if err != nil {
return fmt.Errorf("error checking if rule exists: %w", err)
}
if !exists {
return table.Append("nat", "PREROUTING", args...)
if exists {
return nil
}
return nil
return table.Append("nat", "PREROUTING", args...)
}
// DeleteMapRuleForSvc constructs a prerouting rule as would be created by
@ -40,10 +41,41 @@ func (i *iptablesRunner) DeletePortMapRuleForSvc(svc, excludeI string, targetIP
if err != nil {
return fmt.Errorf("error checking if rule exists: %w", err)
}
if exists {
return table.Delete("nat", "PREROUTING", args...)
if !exists {
return nil
}
return nil
return table.Delete("nat", "PREROUTING", args...)
}
// EnsureDNATRuleForSvc adds a DNAT rule that forwards traffic from the
// VIPService IP address to a local address. This is used by the Kubernetes
// operator's network layer proxies to forward tailnet traffic for VIPServices
// to Kubernetes Services.
func (i *iptablesRunner) EnsureDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error {
table := i.getIPTByAddr(dst)
args := argsForIngressRule(svcName, origDst, dst)
exists, err := table.Exists("nat", "PREROUTING", args...)
if err != nil {
return fmt.Errorf("error checking if rule exists: %w", err)
}
if exists {
return nil
}
return table.Append("nat", "PREROUTING", args...)
}
// DeleteDNATRuleForSvc deletes a DNAT rule created by EnsureDNATRuleForSvc.
func (i *iptablesRunner) DeleteDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error {
table := i.getIPTByAddr(dst)
args := argsForIngressRule(svcName, origDst, dst)
exists, err := table.Exists("nat", "PREROUTING", args...)
if err != nil {
return fmt.Errorf("error checking if rule exists: %w", err)
}
if !exists {
return nil
}
return table.Delete("nat", "PREROUTING", args...)
}
// DeleteSvc constructs all possible rules that would have been created by
@ -72,8 +104,24 @@ func argsForPortMapRule(svc, excludeI string, targetIP netip.Addr, pm PortMap) [
}
}
func argsForIngressRule(svcName string, origDst, targetIP netip.Addr) []string {
c := commentForIngressSvc(svcName, origDst, targetIP)
return []string{
"--destination", origDst.String(),
"-m", "comment", "--comment", c,
"-j", "DNAT",
"--to-destination", targetIP.String(),
}
}
// commentForSvc generates a comment to be added to an iptables DNAT rule for a
// service. This is for iptables debugging/readability purposes only.
func commentForSvc(svc string, pm PortMap) string {
return fmt.Sprintf("%s:%s:%d -> %s:%d", svc, pm.Protocol, pm.MatchPort, pm.Protocol, pm.TargetPort)
}
// commentForIngressSvc generates a comment to be added to an iptables DNAT rule for a
// service. This is for iptables debugging/readability purposes only.
func commentForIngressSvc(svc string, vip, clusterIP netip.Addr) string {
return fmt.Sprintf("svc: %s, %s -> %s", svc, vip.String(), clusterIP.String())
}

View File

@ -153,6 +153,135 @@ func Test_iptablesRunner_DeleteSvc(t *testing.T) {
svcMustExist(t, "svc2", map[string][]string{v4Addr.String(): s2R1, v6Addr.String(): s2R2}, iptr)
}
func Test_iptablesRunner_EnsureDNATRuleForSvc(t *testing.T) {
v4OrigDst := netip.MustParseAddr("10.0.0.1")
v4Target := netip.MustParseAddr("10.0.0.2")
v6OrigDst := netip.MustParseAddr("fd7a:115c:a1e0::1")
v6Target := netip.MustParseAddr("fd7a:115c:a1e0::2")
v4Rule := argsForIngressRule("svc:test", v4OrigDst, v4Target)
tests := []struct {
name string
svcName string
origDst netip.Addr
targetIP netip.Addr
precreateSvcRules [][]string
}{
{
name: "dnat_for_ipv4",
svcName: "svc:test",
origDst: v4OrigDst,
targetIP: v4Target,
},
{
name: "dnat_for_ipv6",
svcName: "svc:test-2",
origDst: v6OrigDst,
targetIP: v6Target,
},
{
name: "add_existing_rule",
svcName: "svc:test",
origDst: v4OrigDst,
targetIP: v4Target,
precreateSvcRules: [][]string{v4Rule},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
iptr := NewFakeIPTablesRunner()
table := iptr.getIPTByAddr(tt.targetIP)
for _, ruleset := range tt.precreateSvcRules {
mustPrecreateDNATRule(t, ruleset, table)
}
if err := iptr.EnsureDNATRuleForSvc(tt.svcName, tt.origDst, tt.targetIP); err != nil {
t.Errorf("[unexpected error] iptablesRunner.EnsureDNATRuleForSvc() = %v", err)
}
args := argsForIngressRule(tt.svcName, tt.origDst, tt.targetIP)
exists, err := table.Exists("nat", "PREROUTING", args...)
if err != nil {
t.Fatalf("error checking if rule exists: %v", err)
}
if !exists {
t.Errorf("expected rule was not created")
}
})
}
}
func Test_iptablesRunner_DeleteDNATRuleForSvc(t *testing.T) {
v4OrigDst := netip.MustParseAddr("10.0.0.1")
v4Target := netip.MustParseAddr("10.0.0.2")
v6OrigDst := netip.MustParseAddr("fd7a:115c:a1e0::1")
v6Target := netip.MustParseAddr("fd7a:115c:a1e0::2")
v4Rule := argsForIngressRule("svc:test", v4OrigDst, v4Target)
v6Rule := argsForIngressRule("svc:test", v6OrigDst, v6Target)
tests := []struct {
name string
svcName string
origDst netip.Addr
targetIP netip.Addr
precreateSvcRules [][]string
}{
{
name: "multiple_rules_ipv4_deleted",
svcName: "svc:test",
origDst: v4OrigDst,
targetIP: v4Target,
precreateSvcRules: [][]string{v4Rule, v6Rule},
},
{
name: "multiple_rules_ipv6_deleted",
svcName: "svc:test",
origDst: v6OrigDst,
targetIP: v6Target,
precreateSvcRules: [][]string{v4Rule, v6Rule},
},
{
name: "non-existent_rule_deleted",
svcName: "svc:test",
origDst: v4OrigDst,
targetIP: v4Target,
precreateSvcRules: [][]string{v6Rule},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
iptr := NewFakeIPTablesRunner()
table := iptr.getIPTByAddr(tt.targetIP)
for _, ruleset := range tt.precreateSvcRules {
mustPrecreateDNATRule(t, ruleset, table)
}
if err := iptr.DeleteDNATRuleForSvc(tt.svcName, tt.origDst, tt.targetIP); err != nil {
t.Errorf("iptablesRunner.DeleteDNATRuleForSvc() errored: %v ", err)
}
deletedRule := argsForIngressRule(tt.svcName, tt.origDst, tt.targetIP)
exists, err := table.Exists("nat", "PREROUTING", deletedRule...)
if err != nil {
t.Fatalf("error verifying that rule does not exist after deletion: %v", err)
}
if exists {
t.Errorf("DNAT rule exists after deletion")
}
})
}
}
func mustPrecreateDNATRule(t *testing.T, rules []string, table iptablesInterface) {
t.Helper()
exists, err := table.Exists("nat", "PREROUTING", rules...)
if err != nil {
t.Fatalf("error ensuring that nat PREROUTING table exists: %v", err)
}
if exists {
return
}
if err := table.Append("nat", "PREROUTING", rules...); err != nil {
t.Fatalf("error precreating DNAT rule: %v", err)
}
}
func svcMustExist(t *testing.T, svcName string, rules map[string][]string, iptr *iptablesRunner) {
t.Helper()
for dst, ruleset := range rules {

View File

@ -119,6 +119,63 @@ func (n *nftablesRunner) DeleteSvc(svc, tun string, targetIPs []netip.Addr, pm [
return n.conn.Flush()
}
// EnsureDNATRuleForSvc adds a DNAT rule that forwards traffic from the
// VIPService IP address to a local address. This is used by the Kubernetes
// operator's network layer proxies to forward tailnet traffic for VIPServices
// to Kubernetes Services.
func (n *nftablesRunner) EnsureDNATRuleForSvc(svc string, origDst, dst netip.Addr) error {
t, ch, err := n.ensurePreroutingChain(origDst)
if err != nil {
return fmt.Errorf("error ensuring chain for %s: %w", svc, err)
}
meta := svcRuleMeta(svc, origDst, dst)
rule, err := n.findRuleByMetadata(t, ch, meta)
if err != nil {
return fmt.Errorf("error looking up rule: %w", err)
}
if rule != nil {
return nil
}
rule = dnatRuleForChain(t, ch, origDst, dst, meta)
n.conn.InsertRule(rule)
return n.conn.Flush()
}
// DeleteDNATRuleForSvc deletes a DNAT rule created by EnsureDNATRuleForSvc.
// We use the metadata attached to the rule to look it up.
func (n *nftablesRunner) DeleteDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error {
table, err := n.getNFTByAddr(origDst)
if err != nil {
return fmt.Errorf("error setting up nftables for IP family of %s: %w", origDst, err)
}
t, err := getTableIfExists(n.conn, table.Proto, "nat")
if err != nil {
return fmt.Errorf("error checking if nat table exists: %w", err)
}
if t == nil {
return nil
}
ch, err := getChainFromTable(n.conn, t, "PREROUTING")
if errors.Is(err, errorChainNotFound{tableName: "nat", chainName: "PREROUTING"}) {
return nil
}
if err != nil {
return fmt.Errorf("error checking if chain PREROUTING exists: %w", err)
}
meta := svcRuleMeta(svcName, origDst, dst)
rule, err := n.findRuleByMetadata(t, ch, meta)
if err != nil {
return fmt.Errorf("error checking if rule exists: %w", err)
}
if rule == nil {
return nil
}
if err := n.conn.DelRule(rule); err != nil {
return fmt.Errorf("error deleting rule: %w", err)
}
return n.conn.Flush()
}
func portMapRule(t *nftables.Table, ch *nftables.Chain, tun string, targetIP netip.Addr, matchPort, targetPort uint16, proto uint8, meta []byte) *nftables.Rule {
var fam uint32
if targetIP.Is4() {
@ -243,3 +300,10 @@ func protoFromString(s string) (uint8, error) {
return 0, fmt.Errorf("unrecognized protocol: %q", s)
}
}
// svcRuleMeta generates metadata for a rule.
// This metadata can then be used to find the rule.
// https://github.com/google/nftables/issues/48
func svcRuleMeta(svcName string, origDst, dst netip.Addr) []byte {
return []byte(fmt.Sprintf("svc:%s,VIP:%s,ClusterIP:%s", svcName, origDst.String(), dst.String()))
}

View File

@ -14,8 +14,9 @@ import (
// This test creates a temporary network namespace for the nftables rules being
// set up, so it needs to run in a privileged mode. Locally it needs to be run
// by root, else it will be silently skipped. In CI it runs in a privileged
// container.
// by root, else it will be silently skipped.
// sudo go test -v -run Test_nftablesRunner_EnsurePortMapRuleForSvc ./util/linuxfw/...
// In CI it runs in a privileged container.
func Test_nftablesRunner_EnsurePortMapRuleForSvc(t *testing.T) {
conn := newSysConn(t)
runner := newFakeNftablesRunnerWithConn(t, conn, true)
@ -23,51 +24,215 @@ func Test_nftablesRunner_EnsurePortMapRuleForSvc(t *testing.T) {
pmTCP := PortMap{MatchPort: 4003, TargetPort: 80, Protocol: "TCP"}
pmTCP1 := PortMap{MatchPort: 4004, TargetPort: 443, Protocol: "TCP"}
// Create a rule for service 'foo' to forward TCP traffic to IPv4 endpoint
runner.EnsurePortMapRuleForSvc("foo", "tailscale0", ipv4, pmTCP)
// Create a rule for service 'svc:foo' to forward TCP traffic to IPv4 endpoint
runner.EnsurePortMapRuleForSvc("svc:foo", "tailscale0", ipv4, pmTCP)
svcChains(t, 1, conn)
chainRuleCount(t, "foo", 1, conn, nftables.TableFamilyIPv4)
checkPortMapRule(t, "foo", ipv4, pmTCP, runner, nftables.TableFamilyIPv4)
chainRuleCount(t, "svc:foo", 1, conn, nftables.TableFamilyIPv4)
checkPortMapRule(t, "svc:foo", ipv4, pmTCP, runner, nftables.TableFamilyIPv4)
// Create another rule for service 'foo' to forward TCP traffic to the
// Create another rule for service 'svc:foo' to forward TCP traffic to the
// same IPv4 endpoint, but to a different port.
runner.EnsurePortMapRuleForSvc("foo", "tailscale0", ipv4, pmTCP1)
runner.EnsurePortMapRuleForSvc("svc:foo", "tailscale0", ipv4, pmTCP1)
svcChains(t, 1, conn)
chainRuleCount(t, "foo", 2, conn, nftables.TableFamilyIPv4)
checkPortMapRule(t, "foo", ipv4, pmTCP1, runner, nftables.TableFamilyIPv4)
chainRuleCount(t, "svc:foo", 2, conn, nftables.TableFamilyIPv4)
checkPortMapRule(t, "svc:foo", ipv4, pmTCP1, runner, nftables.TableFamilyIPv4)
// Create a rule for service 'foo' to forward TCP traffic to an IPv6 endpoint
runner.EnsurePortMapRuleForSvc("foo", "tailscale0", ipv6, pmTCP)
// Create a rule for service 'svc:foo' to forward TCP traffic to an IPv6 endpoint
runner.EnsurePortMapRuleForSvc("svc:foo", "tailscale0", ipv6, pmTCP)
svcChains(t, 2, conn)
chainRuleCount(t, "foo", 1, conn, nftables.TableFamilyIPv6)
checkPortMapRule(t, "foo", ipv6, pmTCP, runner, nftables.TableFamilyIPv6)
chainRuleCount(t, "svc:foo", 1, conn, nftables.TableFamilyIPv6)
checkPortMapRule(t, "svc:foo", ipv6, pmTCP, runner, nftables.TableFamilyIPv6)
// Create a rule for service 'bar' to forward TCP traffic to IPv4 endpoint
runner.EnsurePortMapRuleForSvc("bar", "tailscale0", ipv4, pmTCP)
// Create a rule for service 'svc:bar' to forward TCP traffic to IPv4 endpoint
runner.EnsurePortMapRuleForSvc("svc:bar", "tailscale0", ipv4, pmTCP)
svcChains(t, 3, conn)
chainRuleCount(t, "bar", 1, conn, nftables.TableFamilyIPv4)
checkPortMapRule(t, "bar", ipv4, pmTCP, runner, nftables.TableFamilyIPv4)
chainRuleCount(t, "svc:bar", 1, conn, nftables.TableFamilyIPv4)
checkPortMapRule(t, "svc:bar", ipv4, pmTCP, runner, nftables.TableFamilyIPv4)
// Create a rule for service 'bar' to forward TCP traffic to an IPv6 endpoint
runner.EnsurePortMapRuleForSvc("bar", "tailscale0", ipv6, pmTCP)
// Create a rule for service 'svc:bar' to forward TCP traffic to an IPv6 endpoint
runner.EnsurePortMapRuleForSvc("svc:bar", "tailscale0", ipv6, pmTCP)
svcChains(t, 4, conn)
chainRuleCount(t, "bar", 1, conn, nftables.TableFamilyIPv6)
checkPortMapRule(t, "bar", ipv6, pmTCP, runner, nftables.TableFamilyIPv6)
chainRuleCount(t, "svc:bar", 1, conn, nftables.TableFamilyIPv6)
checkPortMapRule(t, "svc:bar", ipv6, pmTCP, runner, nftables.TableFamilyIPv6)
// Delete service bar
runner.DeleteSvc("bar", "tailscale0", []netip.Addr{ipv4, ipv6}, []PortMap{pmTCP})
// Delete service svc:bar
runner.DeleteSvc("svc:bar", "tailscale0", []netip.Addr{ipv4, ipv6}, []PortMap{pmTCP})
svcChains(t, 2, conn)
// Delete a rule from service foo
runner.DeletePortMapRuleForSvc("foo", "tailscale0", ipv4, pmTCP)
// Delete a rule from service svc:foo
runner.DeletePortMapRuleForSvc("svc:foo", "tailscale0", ipv4, pmTCP)
svcChains(t, 2, conn)
chainRuleCount(t, "foo", 1, conn, nftables.TableFamilyIPv4)
chainRuleCount(t, "svc:foo", 1, conn, nftables.TableFamilyIPv4)
// Delete service foo
runner.DeleteSvc("foo", "tailscale0", []netip.Addr{ipv4, ipv6}, []PortMap{pmTCP, pmTCP1})
// Delete service svc:foo
runner.DeleteSvc("svc:foo", "tailscale0", []netip.Addr{ipv4, ipv6}, []PortMap{pmTCP, pmTCP1})
svcChains(t, 0, conn)
}
func Test_nftablesRunner_EnsureDNATRuleForSvc(t *testing.T) {
conn := newSysConn(t)
runner := newFakeNftablesRunnerWithConn(t, conn, true)
// Test IPv4 DNAT rule
ipv4OrigDst := netip.MustParseAddr("10.0.0.1")
ipv4Target := netip.MustParseAddr("10.0.0.2")
// Create DNAT rule for service 'svc:foo' to forward IPv4 traffic
err := runner.EnsureDNATRuleForSvc("svc:foo", ipv4OrigDst, ipv4Target)
if err != nil {
t.Fatalf("error creating IPv4 DNAT rule: %v", err)
}
checkDNATRule(t, "svc:foo", ipv4OrigDst, ipv4Target, runner, nftables.TableFamilyIPv4)
// Test IPv6 DNAT rule
ipv6OrigDst := netip.MustParseAddr("fd7a:115c:a1e0::1")
ipv6Target := netip.MustParseAddr("fd7a:115c:a1e0::2")
// Create DNAT rule for service 'svc:foo' to forward IPv6 traffic
err = runner.EnsureDNATRuleForSvc("svc:foo", ipv6OrigDst, ipv6Target)
if err != nil {
t.Fatalf("error creating IPv6 DNAT rule: %v", err)
}
checkDNATRule(t, "svc:foo", ipv6OrigDst, ipv6Target, runner, nftables.TableFamilyIPv6)
// Test creating rule for another service
err = runner.EnsureDNATRuleForSvc("svc:bar", ipv4OrigDst, ipv4Target)
if err != nil {
t.Fatalf("error creating DNAT rule for service 'svc:bar': %v", err)
}
checkDNATRule(t, "svc:bar", ipv4OrigDst, ipv4Target, runner, nftables.TableFamilyIPv4)
}
func Test_nftablesRunner_DeleteDNATRuleForSvc(t *testing.T) {
conn := newSysConn(t)
runner := newFakeNftablesRunnerWithConn(t, conn, true)
// Test IPv4 DNAT rule deletion
ipv4OrigDst := netip.MustParseAddr("10.0.0.1")
ipv4Target := netip.MustParseAddr("10.0.0.2")
// Create and then delete IPv4 DNAT rule
err := runner.EnsureDNATRuleForSvc("svc:foo", ipv4OrigDst, ipv4Target)
if err != nil {
t.Fatalf("error creating IPv4 DNAT rule: %v", err)
}
// Verify rule exists before deletion
table, err := runner.getNFTByAddr(ipv4OrigDst)
if err != nil {
t.Fatalf("error getting table: %v", err)
}
nftTable, err := getTableIfExists(runner.conn, table.Proto, "nat")
if err != nil {
t.Fatalf("error getting nat table: %v", err)
}
ch, err := getChainFromTable(runner.conn, nftTable, "PREROUTING")
if err != nil {
t.Fatalf("error getting PREROUTING chain: %v", err)
}
meta := svcRuleMeta("svc:foo", ipv4OrigDst, ipv4Target)
rule, err := runner.findRuleByMetadata(nftTable, ch, meta)
if err != nil {
t.Fatalf("error checking if rule exists: %v", err)
}
if rule == nil {
t.Fatal("rule does not exist before deletion")
}
err = runner.DeleteDNATRuleForSvc("svc:foo", ipv4OrigDst, ipv4Target)
if err != nil {
t.Fatalf("error deleting IPv4 DNAT rule: %v", err)
}
// Verify rule is deleted
rule, err = runner.findRuleByMetadata(nftTable, ch, meta)
if err != nil {
t.Fatalf("error checking if rule exists: %v", err)
}
if rule != nil {
t.Fatal("rule still exists after deletion")
}
// Test IPv6 DNAT rule deletion
ipv6OrigDst := netip.MustParseAddr("fd7a:115c:a1e0::1")
ipv6Target := netip.MustParseAddr("fd7a:115c:a1e0::2")
// Create and then delete IPv6 DNAT rule
err = runner.EnsureDNATRuleForSvc("svc:foo", ipv6OrigDst, ipv6Target)
if err != nil {
t.Fatalf("error creating IPv6 DNAT rule: %v", err)
}
// Verify rule exists before deletion
table, err = runner.getNFTByAddr(ipv6OrigDst)
if err != nil {
t.Fatalf("error getting table: %v", err)
}
nftTable, err = getTableIfExists(runner.conn, table.Proto, "nat")
if err != nil {
t.Fatalf("error getting nat table: %v", err)
}
ch, err = getChainFromTable(runner.conn, nftTable, "PREROUTING")
if err != nil {
t.Fatalf("error getting PREROUTING chain: %v", err)
}
meta = svcRuleMeta("svc:foo", ipv6OrigDst, ipv6Target)
rule, err = runner.findRuleByMetadata(nftTable, ch, meta)
if err != nil {
t.Fatalf("error checking if rule exists: %v", err)
}
if rule == nil {
t.Fatal("rule does not exist before deletion")
}
err = runner.DeleteDNATRuleForSvc("svc:foo", ipv6OrigDst, ipv6Target)
if err != nil {
t.Fatalf("error deleting IPv6 DNAT rule: %v", err)
}
// Verify rule is deleted
rule, err = runner.findRuleByMetadata(nftTable, ch, meta)
if err != nil {
t.Fatalf("error checking if rule exists: %v", err)
}
if rule != nil {
t.Fatal("rule still exists after deletion")
}
}
// checkDNATRule verifies that a DNAT rule exists for the given service, original destination, and target IP.
func checkDNATRule(t *testing.T, svc string, origDst, targetIP netip.Addr, runner *nftablesRunner, fam nftables.TableFamily) {
t.Helper()
table, err := runner.getNFTByAddr(origDst)
if err != nil {
t.Fatalf("error getting table: %v", err)
}
nftTable, err := getTableIfExists(runner.conn, table.Proto, "nat")
if err != nil {
t.Fatalf("error getting nat table: %v", err)
}
if nftTable == nil {
t.Fatal("nat table not found")
}
ch, err := getChainFromTable(runner.conn, nftTable, "PREROUTING")
if err != nil {
t.Fatalf("error getting PREROUTING chain: %v", err)
}
if ch == nil {
t.Fatal("PREROUTING chain not found")
}
meta := svcRuleMeta(svc, origDst, targetIP)
rule, err := runner.findRuleByMetadata(nftTable, ch, meta)
if err != nil {
t.Fatalf("error checking if rule exists: %v", err)
}
if rule == nil {
t.Fatal("DNAT rule not found")
}
}
// svcChains verifies that the expected number of chains exist (for either IP
// family) and that each of them is configured as NAT prerouting chain.
func svcChains(t *testing.T, wantCount int, conn *nftables.Conn) {

View File

@ -107,6 +107,12 @@ func (n *nftablesRunner) AddDNATRule(origDst netip.Addr, dst netip.Addr) error {
if err != nil {
return err
}
rule := dnatRuleForChain(nat, preroutingCh, origDst, dst, nil)
n.conn.InsertRule(rule)
return n.conn.Flush()
}
func dnatRuleForChain(t *nftables.Table, ch *nftables.Chain, origDst, dst netip.Addr, meta []byte) *nftables.Rule {
var daddrOffset, fam, dadderLen uint32
if origDst.Is4() {
daddrOffset = 16
@ -117,9 +123,9 @@ func (n *nftablesRunner) AddDNATRule(origDst netip.Addr, dst netip.Addr) error {
dadderLen = 16
fam = unix.NFPROTO_IPV6
}
dnatRule := &nftables.Rule{
Table: nat,
Chain: preroutingCh,
rule := &nftables.Rule{
Table: t,
Chain: ch,
Exprs: []expr.Any{
&expr.Payload{
DestRegister: 1,
@ -143,8 +149,10 @@ func (n *nftablesRunner) AddDNATRule(origDst netip.Addr, dst netip.Addr) error {
},
},
}
n.conn.InsertRule(dnatRule)
return n.conn.Flush()
if len(meta) > 0 {
rule.UserData = meta
}
return rule
}
// DNATWithLoadBalancer currently just forwards all traffic destined for origDst
@ -555,6 +563,8 @@ type NetfilterRunner interface {
EnsurePortMapRuleForSvc(svc, tun string, targetIP netip.Addr, pm PortMap) error
DeletePortMapRuleForSvc(svc, tun string, targetIP netip.Addr, pm PortMap) error
EnsureDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error
DeleteDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error
DeleteSvc(svc, tun string, targetIPs []netip.Addr, pm []PortMap) error

View File

@ -557,6 +557,14 @@ func (n *fakeIPTablesRunner) ClampMSSToPMTU(tun string, addr netip.Addr) error {
return errors.New("not implemented")
}
func (n *fakeIPTablesRunner) EnsureDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error {
return errors.New("not implemented")
}
func (n *fakeIPTablesRunner) DeleteDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error {
return errors.New("not implemented")
}
func (n *fakeIPTablesRunner) addBase4(tunname string) error {
curIPT := n.ipt4
newRules := []struct{ chain, rule string }{