util/linuxfw,wgengine/router: add new netfilter rules for HA ingresses (#15896)

Add new rules to update DNAT rules for Kubernetes operator's
HA ingress where it's expected that rules will be added/removed
frequently (so we don't want to keep old rules around or rewrite
existing rules unnecessarily):
- allow deleting DNAT rules using metadata lookup
- allow inserting DNAT rules if they don't already
exist (using metadata lookup)

Updates tailscale/tailscale#15895

Signed-off-by: Irbe Krumina <irbe@tailscale.com>
Co-authored-by: chaosinthecrd <tom@tmlabs.co.uk>
This commit is contained in:
Irbe Krumina 2025-05-12 17:26:23 +01:00 committed by GitHub
parent d6dd74fe0e
commit 2c16fcaa06
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 559 additions and 40 deletions

View File

@ -0,0 +1,95 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build linux
package linuxfw
import (
"net/netip"
"tailscale.com/types/logger"
)
// FakeNetfilterRunner is a fake netfilter runner for tests.
type FakeNetfilterRunner struct {
// services is a map that tracks the firewall rules added/deleted via
// EnsureDNATRuleForSvc/DeleteDNATRuleForSvc.
services map[string]struct {
VIPServiceIP netip.Addr
ClusterIP netip.Addr
}
}
// NewFakeNetfilterRunner creates a new FakeNetfilterRunner.
func NewFakeNetfilterRunner() *FakeNetfilterRunner {
return &FakeNetfilterRunner{
services: make(map[string]struct {
VIPServiceIP netip.Addr
ClusterIP netip.Addr
}),
}
}
func (f *FakeNetfilterRunner) EnsureDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error {
f.services[svcName] = struct {
VIPServiceIP netip.Addr
ClusterIP netip.Addr
}{origDst, dst}
return nil
}
func (f *FakeNetfilterRunner) DeleteDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error {
delete(f.services, svcName)
return nil
}
func (f *FakeNetfilterRunner) GetServiceState() map[string]struct {
VIPServiceIP netip.Addr
ClusterIP netip.Addr
} {
return f.services
}
func (f *FakeNetfilterRunner) HasIPV6() bool {
return true
}
func (f *FakeNetfilterRunner) HasIPV6Filter() bool {
return true
}
func (f *FakeNetfilterRunner) HasIPV6NAT() bool {
return true
}
func (f *FakeNetfilterRunner) AddBase(tunname string) error { return nil }
func (f *FakeNetfilterRunner) DelBase() error { return nil }
func (f *FakeNetfilterRunner) AddChains() error { return nil }
func (f *FakeNetfilterRunner) DelChains() error { return nil }
func (f *FakeNetfilterRunner) AddHooks() error { return nil }
func (f *FakeNetfilterRunner) DelHooks(logf logger.Logf) error { return nil }
func (f *FakeNetfilterRunner) AddSNATRule() error { return nil }
func (f *FakeNetfilterRunner) DelSNATRule() error { return nil }
func (f *FakeNetfilterRunner) AddStatefulRule(tunname string) error { return nil }
func (f *FakeNetfilterRunner) DelStatefulRule(tunname string) error { return nil }
func (f *FakeNetfilterRunner) AddLoopbackRule(addr netip.Addr) error { return nil }
func (f *FakeNetfilterRunner) DelLoopbackRule(addr netip.Addr) error { return nil }
func (f *FakeNetfilterRunner) AddDNATRule(origDst, dst netip.Addr) error { return nil }
func (f *FakeNetfilterRunner) DNATWithLoadBalancer(origDst netip.Addr, dsts []netip.Addr) error {
return nil
}
func (f *FakeNetfilterRunner) EnsureSNATForDst(src, dst netip.Addr) error { return nil }
func (f *FakeNetfilterRunner) DNATNonTailscaleTraffic(tun string, dst netip.Addr) error { return nil }
func (f *FakeNetfilterRunner) ClampMSSToPMTU(tun string, addr netip.Addr) error { return nil }
func (f *FakeNetfilterRunner) AddMagicsockPortRule(port uint16, network string) error { return nil }
func (f *FakeNetfilterRunner) DelMagicsockPortRule(port uint16, network string) error { return nil }
func (f *FakeNetfilterRunner) DeletePortMapRuleForSvc(svc, tun string, targetIP netip.Addr, pm PortMap) error {
return nil
}
func (f *FakeNetfilterRunner) DeleteSvc(svc, tun string, targetIPs []netip.Addr, pms []PortMap) error {
return nil
}
func (f *FakeNetfilterRunner) EnsurePortMapRuleForSvc(svc, tun string, targetIP netip.Addr, pm PortMap) error {
return nil
}

View File

@ -13,6 +13,7 @@ import (
// This file contains functionality to insert portmapping rules for a 'service'. // This file contains functionality to insert portmapping rules for a 'service'.
// These are currently only used by the Kubernetes operator proxies. // These are currently only used by the Kubernetes operator proxies.
// An iptables rule for such a service contains a comment with the service name. // An iptables rule for such a service contains a comment with the service name.
// A 'service' corresponds to a VIPService as used by the Kubernetes operator.
// EnsurePortMapRuleForSvc adds a prerouting rule that forwards traffic received // EnsurePortMapRuleForSvc adds a prerouting rule that forwards traffic received
// on match port and NOT on the provided interface to target IP and target port. // on match port and NOT on the provided interface to target IP and target port.
@ -24,11 +25,11 @@ func (i *iptablesRunner) EnsurePortMapRuleForSvc(svc, tun string, targetIP netip
if err != nil { if err != nil {
return fmt.Errorf("error checking if rule exists: %w", err) return fmt.Errorf("error checking if rule exists: %w", err)
} }
if !exists { if exists {
return table.Append("nat", "PREROUTING", args...)
}
return nil return nil
} }
return table.Append("nat", "PREROUTING", args...)
}
// DeleteMapRuleForSvc constructs a prerouting rule as would be created by // DeleteMapRuleForSvc constructs a prerouting rule as would be created by
// EnsurePortMapRuleForSvc with the provided args and, if such a rule exists, // EnsurePortMapRuleForSvc with the provided args and, if such a rule exists,
@ -40,11 +41,42 @@ func (i *iptablesRunner) DeletePortMapRuleForSvc(svc, excludeI string, targetIP
if err != nil { if err != nil {
return fmt.Errorf("error checking if rule exists: %w", err) return fmt.Errorf("error checking if rule exists: %w", err)
} }
if exists { if !exists {
return nil
}
return table.Delete("nat", "PREROUTING", args...) return table.Delete("nat", "PREROUTING", args...)
} }
// EnsureDNATRuleForSvc adds a DNAT rule that forwards traffic from the
// VIPService IP address to a local address. This is used by the Kubernetes
// operator's network layer proxies to forward tailnet traffic for VIPServices
// to Kubernetes Services.
func (i *iptablesRunner) EnsureDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error {
table := i.getIPTByAddr(dst)
args := argsForIngressRule(svcName, origDst, dst)
exists, err := table.Exists("nat", "PREROUTING", args...)
if err != nil {
return fmt.Errorf("error checking if rule exists: %w", err)
}
if exists {
return nil return nil
} }
return table.Append("nat", "PREROUTING", args...)
}
// DeleteDNATRuleForSvc deletes a DNAT rule created by EnsureDNATRuleForSvc.
func (i *iptablesRunner) DeleteDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error {
table := i.getIPTByAddr(dst)
args := argsForIngressRule(svcName, origDst, dst)
exists, err := table.Exists("nat", "PREROUTING", args...)
if err != nil {
return fmt.Errorf("error checking if rule exists: %w", err)
}
if !exists {
return nil
}
return table.Delete("nat", "PREROUTING", args...)
}
// DeleteSvc constructs all possible rules that would have been created by // DeleteSvc constructs all possible rules that would have been created by
// EnsurePortMapRuleForSvc from the provided args and ensures that each one that // EnsurePortMapRuleForSvc from the provided args and ensures that each one that
@ -72,8 +104,24 @@ func argsForPortMapRule(svc, excludeI string, targetIP netip.Addr, pm PortMap) [
} }
} }
func argsForIngressRule(svcName string, origDst, targetIP netip.Addr) []string {
c := commentForIngressSvc(svcName, origDst, targetIP)
return []string{
"--destination", origDst.String(),
"-m", "comment", "--comment", c,
"-j", "DNAT",
"--to-destination", targetIP.String(),
}
}
// commentForSvc generates a comment to be added to an iptables DNAT rule for a // commentForSvc generates a comment to be added to an iptables DNAT rule for a
// service. This is for iptables debugging/readability purposes only. // service. This is for iptables debugging/readability purposes only.
func commentForSvc(svc string, pm PortMap) string { func commentForSvc(svc string, pm PortMap) string {
return fmt.Sprintf("%s:%s:%d -> %s:%d", svc, pm.Protocol, pm.MatchPort, pm.Protocol, pm.TargetPort) return fmt.Sprintf("%s:%s:%d -> %s:%d", svc, pm.Protocol, pm.MatchPort, pm.Protocol, pm.TargetPort)
} }
// commentForIngressSvc generates a comment to be added to an iptables DNAT rule for a
// service. This is for iptables debugging/readability purposes only.
func commentForIngressSvc(svc string, vip, clusterIP netip.Addr) string {
return fmt.Sprintf("svc: %s, %s -> %s", svc, vip.String(), clusterIP.String())
}

View File

@ -153,6 +153,135 @@ func Test_iptablesRunner_DeleteSvc(t *testing.T) {
svcMustExist(t, "svc2", map[string][]string{v4Addr.String(): s2R1, v6Addr.String(): s2R2}, iptr) svcMustExist(t, "svc2", map[string][]string{v4Addr.String(): s2R1, v6Addr.String(): s2R2}, iptr)
} }
func Test_iptablesRunner_EnsureDNATRuleForSvc(t *testing.T) {
v4OrigDst := netip.MustParseAddr("10.0.0.1")
v4Target := netip.MustParseAddr("10.0.0.2")
v6OrigDst := netip.MustParseAddr("fd7a:115c:a1e0::1")
v6Target := netip.MustParseAddr("fd7a:115c:a1e0::2")
v4Rule := argsForIngressRule("svc:test", v4OrigDst, v4Target)
tests := []struct {
name string
svcName string
origDst netip.Addr
targetIP netip.Addr
precreateSvcRules [][]string
}{
{
name: "dnat_for_ipv4",
svcName: "svc:test",
origDst: v4OrigDst,
targetIP: v4Target,
},
{
name: "dnat_for_ipv6",
svcName: "svc:test-2",
origDst: v6OrigDst,
targetIP: v6Target,
},
{
name: "add_existing_rule",
svcName: "svc:test",
origDst: v4OrigDst,
targetIP: v4Target,
precreateSvcRules: [][]string{v4Rule},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
iptr := NewFakeIPTablesRunner()
table := iptr.getIPTByAddr(tt.targetIP)
for _, ruleset := range tt.precreateSvcRules {
mustPrecreateDNATRule(t, ruleset, table)
}
if err := iptr.EnsureDNATRuleForSvc(tt.svcName, tt.origDst, tt.targetIP); err != nil {
t.Errorf("[unexpected error] iptablesRunner.EnsureDNATRuleForSvc() = %v", err)
}
args := argsForIngressRule(tt.svcName, tt.origDst, tt.targetIP)
exists, err := table.Exists("nat", "PREROUTING", args...)
if err != nil {
t.Fatalf("error checking if rule exists: %v", err)
}
if !exists {
t.Errorf("expected rule was not created")
}
})
}
}
func Test_iptablesRunner_DeleteDNATRuleForSvc(t *testing.T) {
v4OrigDst := netip.MustParseAddr("10.0.0.1")
v4Target := netip.MustParseAddr("10.0.0.2")
v6OrigDst := netip.MustParseAddr("fd7a:115c:a1e0::1")
v6Target := netip.MustParseAddr("fd7a:115c:a1e0::2")
v4Rule := argsForIngressRule("svc:test", v4OrigDst, v4Target)
v6Rule := argsForIngressRule("svc:test", v6OrigDst, v6Target)
tests := []struct {
name string
svcName string
origDst netip.Addr
targetIP netip.Addr
precreateSvcRules [][]string
}{
{
name: "multiple_rules_ipv4_deleted",
svcName: "svc:test",
origDst: v4OrigDst,
targetIP: v4Target,
precreateSvcRules: [][]string{v4Rule, v6Rule},
},
{
name: "multiple_rules_ipv6_deleted",
svcName: "svc:test",
origDst: v6OrigDst,
targetIP: v6Target,
precreateSvcRules: [][]string{v4Rule, v6Rule},
},
{
name: "non-existent_rule_deleted",
svcName: "svc:test",
origDst: v4OrigDst,
targetIP: v4Target,
precreateSvcRules: [][]string{v6Rule},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
iptr := NewFakeIPTablesRunner()
table := iptr.getIPTByAddr(tt.targetIP)
for _, ruleset := range tt.precreateSvcRules {
mustPrecreateDNATRule(t, ruleset, table)
}
if err := iptr.DeleteDNATRuleForSvc(tt.svcName, tt.origDst, tt.targetIP); err != nil {
t.Errorf("iptablesRunner.DeleteDNATRuleForSvc() errored: %v ", err)
}
deletedRule := argsForIngressRule(tt.svcName, tt.origDst, tt.targetIP)
exists, err := table.Exists("nat", "PREROUTING", deletedRule...)
if err != nil {
t.Fatalf("error verifying that rule does not exist after deletion: %v", err)
}
if exists {
t.Errorf("DNAT rule exists after deletion")
}
})
}
}
func mustPrecreateDNATRule(t *testing.T, rules []string, table iptablesInterface) {
t.Helper()
exists, err := table.Exists("nat", "PREROUTING", rules...)
if err != nil {
t.Fatalf("error ensuring that nat PREROUTING table exists: %v", err)
}
if exists {
return
}
if err := table.Append("nat", "PREROUTING", rules...); err != nil {
t.Fatalf("error precreating DNAT rule: %v", err)
}
}
func svcMustExist(t *testing.T, svcName string, rules map[string][]string, iptr *iptablesRunner) { func svcMustExist(t *testing.T, svcName string, rules map[string][]string, iptr *iptablesRunner) {
t.Helper() t.Helper()
for dst, ruleset := range rules { for dst, ruleset := range rules {

View File

@ -119,6 +119,63 @@ func (n *nftablesRunner) DeleteSvc(svc, tun string, targetIPs []netip.Addr, pm [
return n.conn.Flush() return n.conn.Flush()
} }
// EnsureDNATRuleForSvc adds a DNAT rule that forwards traffic from the
// VIPService IP address to a local address. This is used by the Kubernetes
// operator's network layer proxies to forward tailnet traffic for VIPServices
// to Kubernetes Services.
func (n *nftablesRunner) EnsureDNATRuleForSvc(svc string, origDst, dst netip.Addr) error {
t, ch, err := n.ensurePreroutingChain(origDst)
if err != nil {
return fmt.Errorf("error ensuring chain for %s: %w", svc, err)
}
meta := svcRuleMeta(svc, origDst, dst)
rule, err := n.findRuleByMetadata(t, ch, meta)
if err != nil {
return fmt.Errorf("error looking up rule: %w", err)
}
if rule != nil {
return nil
}
rule = dnatRuleForChain(t, ch, origDst, dst, meta)
n.conn.InsertRule(rule)
return n.conn.Flush()
}
// DeleteDNATRuleForSvc deletes a DNAT rule created by EnsureDNATRuleForSvc.
// We use the metadata attached to the rule to look it up.
func (n *nftablesRunner) DeleteDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error {
table, err := n.getNFTByAddr(origDst)
if err != nil {
return fmt.Errorf("error setting up nftables for IP family of %s: %w", origDst, err)
}
t, err := getTableIfExists(n.conn, table.Proto, "nat")
if err != nil {
return fmt.Errorf("error checking if nat table exists: %w", err)
}
if t == nil {
return nil
}
ch, err := getChainFromTable(n.conn, t, "PREROUTING")
if errors.Is(err, errorChainNotFound{tableName: "nat", chainName: "PREROUTING"}) {
return nil
}
if err != nil {
return fmt.Errorf("error checking if chain PREROUTING exists: %w", err)
}
meta := svcRuleMeta(svcName, origDst, dst)
rule, err := n.findRuleByMetadata(t, ch, meta)
if err != nil {
return fmt.Errorf("error checking if rule exists: %w", err)
}
if rule == nil {
return nil
}
if err := n.conn.DelRule(rule); err != nil {
return fmt.Errorf("error deleting rule: %w", err)
}
return n.conn.Flush()
}
func portMapRule(t *nftables.Table, ch *nftables.Chain, tun string, targetIP netip.Addr, matchPort, targetPort uint16, proto uint8, meta []byte) *nftables.Rule { func portMapRule(t *nftables.Table, ch *nftables.Chain, tun string, targetIP netip.Addr, matchPort, targetPort uint16, proto uint8, meta []byte) *nftables.Rule {
var fam uint32 var fam uint32
if targetIP.Is4() { if targetIP.Is4() {
@ -243,3 +300,10 @@ func protoFromString(s string) (uint8, error) {
return 0, fmt.Errorf("unrecognized protocol: %q", s) return 0, fmt.Errorf("unrecognized protocol: %q", s)
} }
} }
// svcRuleMeta generates metadata for a rule.
// This metadata can then be used to find the rule.
// https://github.com/google/nftables/issues/48
func svcRuleMeta(svcName string, origDst, dst netip.Addr) []byte {
return []byte(fmt.Sprintf("svc:%s,VIP:%s,ClusterIP:%s", svcName, origDst.String(), dst.String()))
}

View File

@ -14,8 +14,9 @@ import (
// This test creates a temporary network namespace for the nftables rules being // This test creates a temporary network namespace for the nftables rules being
// set up, so it needs to run in a privileged mode. Locally it needs to be run // set up, so it needs to run in a privileged mode. Locally it needs to be run
// by root, else it will be silently skipped. In CI it runs in a privileged // by root, else it will be silently skipped.
// container. // sudo go test -v -run Test_nftablesRunner_EnsurePortMapRuleForSvc ./util/linuxfw/...
// In CI it runs in a privileged container.
func Test_nftablesRunner_EnsurePortMapRuleForSvc(t *testing.T) { func Test_nftablesRunner_EnsurePortMapRuleForSvc(t *testing.T) {
conn := newSysConn(t) conn := newSysConn(t)
runner := newFakeNftablesRunnerWithConn(t, conn, true) runner := newFakeNftablesRunnerWithConn(t, conn, true)
@ -23,51 +24,215 @@ func Test_nftablesRunner_EnsurePortMapRuleForSvc(t *testing.T) {
pmTCP := PortMap{MatchPort: 4003, TargetPort: 80, Protocol: "TCP"} pmTCP := PortMap{MatchPort: 4003, TargetPort: 80, Protocol: "TCP"}
pmTCP1 := PortMap{MatchPort: 4004, TargetPort: 443, Protocol: "TCP"} pmTCP1 := PortMap{MatchPort: 4004, TargetPort: 443, Protocol: "TCP"}
// Create a rule for service 'foo' to forward TCP traffic to IPv4 endpoint // Create a rule for service 'svc:foo' to forward TCP traffic to IPv4 endpoint
runner.EnsurePortMapRuleForSvc("foo", "tailscale0", ipv4, pmTCP) runner.EnsurePortMapRuleForSvc("svc:foo", "tailscale0", ipv4, pmTCP)
svcChains(t, 1, conn) svcChains(t, 1, conn)
chainRuleCount(t, "foo", 1, conn, nftables.TableFamilyIPv4) chainRuleCount(t, "svc:foo", 1, conn, nftables.TableFamilyIPv4)
checkPortMapRule(t, "foo", ipv4, pmTCP, runner, nftables.TableFamilyIPv4) checkPortMapRule(t, "svc:foo", ipv4, pmTCP, runner, nftables.TableFamilyIPv4)
// Create another rule for service 'foo' to forward TCP traffic to the // Create another rule for service 'svc:foo' to forward TCP traffic to the
// same IPv4 endpoint, but to a different port. // same IPv4 endpoint, but to a different port.
runner.EnsurePortMapRuleForSvc("foo", "tailscale0", ipv4, pmTCP1) runner.EnsurePortMapRuleForSvc("svc:foo", "tailscale0", ipv4, pmTCP1)
svcChains(t, 1, conn) svcChains(t, 1, conn)
chainRuleCount(t, "foo", 2, conn, nftables.TableFamilyIPv4) chainRuleCount(t, "svc:foo", 2, conn, nftables.TableFamilyIPv4)
checkPortMapRule(t, "foo", ipv4, pmTCP1, runner, nftables.TableFamilyIPv4) checkPortMapRule(t, "svc:foo", ipv4, pmTCP1, runner, nftables.TableFamilyIPv4)
// Create a rule for service 'foo' to forward TCP traffic to an IPv6 endpoint // Create a rule for service 'svc:foo' to forward TCP traffic to an IPv6 endpoint
runner.EnsurePortMapRuleForSvc("foo", "tailscale0", ipv6, pmTCP) runner.EnsurePortMapRuleForSvc("svc:foo", "tailscale0", ipv6, pmTCP)
svcChains(t, 2, conn) svcChains(t, 2, conn)
chainRuleCount(t, "foo", 1, conn, nftables.TableFamilyIPv6) chainRuleCount(t, "svc:foo", 1, conn, nftables.TableFamilyIPv6)
checkPortMapRule(t, "foo", ipv6, pmTCP, runner, nftables.TableFamilyIPv6) checkPortMapRule(t, "svc:foo", ipv6, pmTCP, runner, nftables.TableFamilyIPv6)
// Create a rule for service 'bar' to forward TCP traffic to IPv4 endpoint // Create a rule for service 'svc:bar' to forward TCP traffic to IPv4 endpoint
runner.EnsurePortMapRuleForSvc("bar", "tailscale0", ipv4, pmTCP) runner.EnsurePortMapRuleForSvc("svc:bar", "tailscale0", ipv4, pmTCP)
svcChains(t, 3, conn) svcChains(t, 3, conn)
chainRuleCount(t, "bar", 1, conn, nftables.TableFamilyIPv4) chainRuleCount(t, "svc:bar", 1, conn, nftables.TableFamilyIPv4)
checkPortMapRule(t, "bar", ipv4, pmTCP, runner, nftables.TableFamilyIPv4) checkPortMapRule(t, "svc:bar", ipv4, pmTCP, runner, nftables.TableFamilyIPv4)
// Create a rule for service 'bar' to forward TCP traffic to an IPv6 endpoint // Create a rule for service 'svc:bar' to forward TCP traffic to an IPv6 endpoint
runner.EnsurePortMapRuleForSvc("bar", "tailscale0", ipv6, pmTCP) runner.EnsurePortMapRuleForSvc("svc:bar", "tailscale0", ipv6, pmTCP)
svcChains(t, 4, conn) svcChains(t, 4, conn)
chainRuleCount(t, "bar", 1, conn, nftables.TableFamilyIPv6) chainRuleCount(t, "svc:bar", 1, conn, nftables.TableFamilyIPv6)
checkPortMapRule(t, "bar", ipv6, pmTCP, runner, nftables.TableFamilyIPv6) checkPortMapRule(t, "svc:bar", ipv6, pmTCP, runner, nftables.TableFamilyIPv6)
// Delete service bar // Delete service svc:bar
runner.DeleteSvc("bar", "tailscale0", []netip.Addr{ipv4, ipv6}, []PortMap{pmTCP}) runner.DeleteSvc("svc:bar", "tailscale0", []netip.Addr{ipv4, ipv6}, []PortMap{pmTCP})
svcChains(t, 2, conn) svcChains(t, 2, conn)
// Delete a rule from service foo // Delete a rule from service svc:foo
runner.DeletePortMapRuleForSvc("foo", "tailscale0", ipv4, pmTCP) runner.DeletePortMapRuleForSvc("svc:foo", "tailscale0", ipv4, pmTCP)
svcChains(t, 2, conn) svcChains(t, 2, conn)
chainRuleCount(t, "foo", 1, conn, nftables.TableFamilyIPv4) chainRuleCount(t, "svc:foo", 1, conn, nftables.TableFamilyIPv4)
// Delete service foo // Delete service svc:foo
runner.DeleteSvc("foo", "tailscale0", []netip.Addr{ipv4, ipv6}, []PortMap{pmTCP, pmTCP1}) runner.DeleteSvc("svc:foo", "tailscale0", []netip.Addr{ipv4, ipv6}, []PortMap{pmTCP, pmTCP1})
svcChains(t, 0, conn) svcChains(t, 0, conn)
} }
func Test_nftablesRunner_EnsureDNATRuleForSvc(t *testing.T) {
conn := newSysConn(t)
runner := newFakeNftablesRunnerWithConn(t, conn, true)
// Test IPv4 DNAT rule
ipv4OrigDst := netip.MustParseAddr("10.0.0.1")
ipv4Target := netip.MustParseAddr("10.0.0.2")
// Create DNAT rule for service 'svc:foo' to forward IPv4 traffic
err := runner.EnsureDNATRuleForSvc("svc:foo", ipv4OrigDst, ipv4Target)
if err != nil {
t.Fatalf("error creating IPv4 DNAT rule: %v", err)
}
checkDNATRule(t, "svc:foo", ipv4OrigDst, ipv4Target, runner, nftables.TableFamilyIPv4)
// Test IPv6 DNAT rule
ipv6OrigDst := netip.MustParseAddr("fd7a:115c:a1e0::1")
ipv6Target := netip.MustParseAddr("fd7a:115c:a1e0::2")
// Create DNAT rule for service 'svc:foo' to forward IPv6 traffic
err = runner.EnsureDNATRuleForSvc("svc:foo", ipv6OrigDst, ipv6Target)
if err != nil {
t.Fatalf("error creating IPv6 DNAT rule: %v", err)
}
checkDNATRule(t, "svc:foo", ipv6OrigDst, ipv6Target, runner, nftables.TableFamilyIPv6)
// Test creating rule for another service
err = runner.EnsureDNATRuleForSvc("svc:bar", ipv4OrigDst, ipv4Target)
if err != nil {
t.Fatalf("error creating DNAT rule for service 'svc:bar': %v", err)
}
checkDNATRule(t, "svc:bar", ipv4OrigDst, ipv4Target, runner, nftables.TableFamilyIPv4)
}
func Test_nftablesRunner_DeleteDNATRuleForSvc(t *testing.T) {
conn := newSysConn(t)
runner := newFakeNftablesRunnerWithConn(t, conn, true)
// Test IPv4 DNAT rule deletion
ipv4OrigDst := netip.MustParseAddr("10.0.0.1")
ipv4Target := netip.MustParseAddr("10.0.0.2")
// Create and then delete IPv4 DNAT rule
err := runner.EnsureDNATRuleForSvc("svc:foo", ipv4OrigDst, ipv4Target)
if err != nil {
t.Fatalf("error creating IPv4 DNAT rule: %v", err)
}
// Verify rule exists before deletion
table, err := runner.getNFTByAddr(ipv4OrigDst)
if err != nil {
t.Fatalf("error getting table: %v", err)
}
nftTable, err := getTableIfExists(runner.conn, table.Proto, "nat")
if err != nil {
t.Fatalf("error getting nat table: %v", err)
}
ch, err := getChainFromTable(runner.conn, nftTable, "PREROUTING")
if err != nil {
t.Fatalf("error getting PREROUTING chain: %v", err)
}
meta := svcRuleMeta("svc:foo", ipv4OrigDst, ipv4Target)
rule, err := runner.findRuleByMetadata(nftTable, ch, meta)
if err != nil {
t.Fatalf("error checking if rule exists: %v", err)
}
if rule == nil {
t.Fatal("rule does not exist before deletion")
}
err = runner.DeleteDNATRuleForSvc("svc:foo", ipv4OrigDst, ipv4Target)
if err != nil {
t.Fatalf("error deleting IPv4 DNAT rule: %v", err)
}
// Verify rule is deleted
rule, err = runner.findRuleByMetadata(nftTable, ch, meta)
if err != nil {
t.Fatalf("error checking if rule exists: %v", err)
}
if rule != nil {
t.Fatal("rule still exists after deletion")
}
// Test IPv6 DNAT rule deletion
ipv6OrigDst := netip.MustParseAddr("fd7a:115c:a1e0::1")
ipv6Target := netip.MustParseAddr("fd7a:115c:a1e0::2")
// Create and then delete IPv6 DNAT rule
err = runner.EnsureDNATRuleForSvc("svc:foo", ipv6OrigDst, ipv6Target)
if err != nil {
t.Fatalf("error creating IPv6 DNAT rule: %v", err)
}
// Verify rule exists before deletion
table, err = runner.getNFTByAddr(ipv6OrigDst)
if err != nil {
t.Fatalf("error getting table: %v", err)
}
nftTable, err = getTableIfExists(runner.conn, table.Proto, "nat")
if err != nil {
t.Fatalf("error getting nat table: %v", err)
}
ch, err = getChainFromTable(runner.conn, nftTable, "PREROUTING")
if err != nil {
t.Fatalf("error getting PREROUTING chain: %v", err)
}
meta = svcRuleMeta("svc:foo", ipv6OrigDst, ipv6Target)
rule, err = runner.findRuleByMetadata(nftTable, ch, meta)
if err != nil {
t.Fatalf("error checking if rule exists: %v", err)
}
if rule == nil {
t.Fatal("rule does not exist before deletion")
}
err = runner.DeleteDNATRuleForSvc("svc:foo", ipv6OrigDst, ipv6Target)
if err != nil {
t.Fatalf("error deleting IPv6 DNAT rule: %v", err)
}
// Verify rule is deleted
rule, err = runner.findRuleByMetadata(nftTable, ch, meta)
if err != nil {
t.Fatalf("error checking if rule exists: %v", err)
}
if rule != nil {
t.Fatal("rule still exists after deletion")
}
}
// checkDNATRule verifies that a DNAT rule exists for the given service, original destination, and target IP.
func checkDNATRule(t *testing.T, svc string, origDst, targetIP netip.Addr, runner *nftablesRunner, fam nftables.TableFamily) {
t.Helper()
table, err := runner.getNFTByAddr(origDst)
if err != nil {
t.Fatalf("error getting table: %v", err)
}
nftTable, err := getTableIfExists(runner.conn, table.Proto, "nat")
if err != nil {
t.Fatalf("error getting nat table: %v", err)
}
if nftTable == nil {
t.Fatal("nat table not found")
}
ch, err := getChainFromTable(runner.conn, nftTable, "PREROUTING")
if err != nil {
t.Fatalf("error getting PREROUTING chain: %v", err)
}
if ch == nil {
t.Fatal("PREROUTING chain not found")
}
meta := svcRuleMeta(svc, origDst, targetIP)
rule, err := runner.findRuleByMetadata(nftTable, ch, meta)
if err != nil {
t.Fatalf("error checking if rule exists: %v", err)
}
if rule == nil {
t.Fatal("DNAT rule not found")
}
}
// svcChains verifies that the expected number of chains exist (for either IP // svcChains verifies that the expected number of chains exist (for either IP
// family) and that each of them is configured as NAT prerouting chain. // family) and that each of them is configured as NAT prerouting chain.
func svcChains(t *testing.T, wantCount int, conn *nftables.Conn) { func svcChains(t *testing.T, wantCount int, conn *nftables.Conn) {

View File

@ -107,6 +107,12 @@ func (n *nftablesRunner) AddDNATRule(origDst netip.Addr, dst netip.Addr) error {
if err != nil { if err != nil {
return err return err
} }
rule := dnatRuleForChain(nat, preroutingCh, origDst, dst, nil)
n.conn.InsertRule(rule)
return n.conn.Flush()
}
func dnatRuleForChain(t *nftables.Table, ch *nftables.Chain, origDst, dst netip.Addr, meta []byte) *nftables.Rule {
var daddrOffset, fam, dadderLen uint32 var daddrOffset, fam, dadderLen uint32
if origDst.Is4() { if origDst.Is4() {
daddrOffset = 16 daddrOffset = 16
@ -117,9 +123,9 @@ func (n *nftablesRunner) AddDNATRule(origDst netip.Addr, dst netip.Addr) error {
dadderLen = 16 dadderLen = 16
fam = unix.NFPROTO_IPV6 fam = unix.NFPROTO_IPV6
} }
dnatRule := &nftables.Rule{ rule := &nftables.Rule{
Table: nat, Table: t,
Chain: preroutingCh, Chain: ch,
Exprs: []expr.Any{ Exprs: []expr.Any{
&expr.Payload{ &expr.Payload{
DestRegister: 1, DestRegister: 1,
@ -143,8 +149,10 @@ func (n *nftablesRunner) AddDNATRule(origDst netip.Addr, dst netip.Addr) error {
}, },
}, },
} }
n.conn.InsertRule(dnatRule) if len(meta) > 0 {
return n.conn.Flush() rule.UserData = meta
}
return rule
} }
// DNATWithLoadBalancer currently just forwards all traffic destined for origDst // DNATWithLoadBalancer currently just forwards all traffic destined for origDst
@ -555,6 +563,8 @@ type NetfilterRunner interface {
EnsurePortMapRuleForSvc(svc, tun string, targetIP netip.Addr, pm PortMap) error EnsurePortMapRuleForSvc(svc, tun string, targetIP netip.Addr, pm PortMap) error
DeletePortMapRuleForSvc(svc, tun string, targetIP netip.Addr, pm PortMap) error DeletePortMapRuleForSvc(svc, tun string, targetIP netip.Addr, pm PortMap) error
EnsureDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error
DeleteDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error
DeleteSvc(svc, tun string, targetIPs []netip.Addr, pm []PortMap) error DeleteSvc(svc, tun string, targetIPs []netip.Addr, pm []PortMap) error

View File

@ -557,6 +557,14 @@ func (n *fakeIPTablesRunner) ClampMSSToPMTU(tun string, addr netip.Addr) error {
return errors.New("not implemented") return errors.New("not implemented")
} }
func (n *fakeIPTablesRunner) EnsureDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error {
return errors.New("not implemented")
}
func (n *fakeIPTablesRunner) DeleteDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error {
return errors.New("not implemented")
}
func (n *fakeIPTablesRunner) addBase4(tunname string) error { func (n *fakeIPTablesRunner) addBase4(tunname string) error {
curIPT := n.ipt4 curIPT := n.ipt4
newRules := []struct{ chain, rule string }{ newRules := []struct{ chain, rule string }{