mirror of
https://github.com/tailscale/tailscale.git
synced 2024-11-25 11:05:45 +00:00
cmd/containerboot,util/linuxfw: create a SNAT rule for dst/src only once, clean up if needed (#13658)
The AddSNATRuleForDst rule was adding a new rule each time it was called including: - if a rule already existed - if a rule matching the destination, but with different desired source already existed This was causing issues especially for the in-progress egress HA proxies work, where the rules are now refreshed more frequently, so more redundant rules were being created. This change: - only creates the rule if it doesn't already exist - if a rule for the same dst, but different source is found, delete it - also ensures that egress proxies refresh firewall rules if the node's tailnet IP changes Updates tailscale/tailscale#13406 Signed-off-by: Irbe Krumina <irbe@tailscale.com>
This commit is contained in:
parent
a3c6a3a34f
commit
9bd158cc09
@ -117,7 +117,7 @@ func installEgressForwardingRule(_ context.Context, dstStr string, tsIPs []netip
|
|||||||
if err := nfr.DNATNonTailscaleTraffic("tailscale0", dst); err != nil {
|
if err := nfr.DNATNonTailscaleTraffic("tailscale0", dst); err != nil {
|
||||||
return fmt.Errorf("installing egress proxy rules: %w", err)
|
return fmt.Errorf("installing egress proxy rules: %w", err)
|
||||||
}
|
}
|
||||||
if err := nfr.AddSNATRuleForDst(local, dst); err != nil {
|
if err := nfr.EnsureSNATForDst(local, dst); err != nil {
|
||||||
return fmt.Errorf("installing egress proxy rules: %w", err)
|
return fmt.Errorf("installing egress proxy rules: %w", err)
|
||||||
}
|
}
|
||||||
if err := nfr.ClampMSSToPMTU("tailscale0", dst); err != nil {
|
if err := nfr.ClampMSSToPMTU("tailscale0", dst); err != nil {
|
||||||
|
@ -481,7 +481,11 @@ func main() {
|
|||||||
egressAddrs = node.Addresses().AsSlice()
|
egressAddrs = node.Addresses().AsSlice()
|
||||||
newCurentEgressIPs = deephash.Hash(&egressAddrs)
|
newCurentEgressIPs = deephash.Hash(&egressAddrs)
|
||||||
egressIPsHaveChanged = newCurentEgressIPs != currentEgressIPs
|
egressIPsHaveChanged = newCurentEgressIPs != currentEgressIPs
|
||||||
if egressIPsHaveChanged && len(egressAddrs) != 0 {
|
// The firewall rules get (re-)installed:
|
||||||
|
// - on startup
|
||||||
|
// - when the tailnet IPs of the tailnet target have changed
|
||||||
|
// - when the tailnet IPs of this node have changed
|
||||||
|
if (egressIPsHaveChanged || ipsHaveChanged) && len(egressAddrs) != 0 {
|
||||||
var rulesInstalled bool
|
var rulesInstalled bool
|
||||||
for _, egressAddr := range egressAddrs {
|
for _, egressAddr := range egressAddrs {
|
||||||
ea := egressAddr.Addr()
|
ea := egressAddr.Addr()
|
||||||
|
@ -196,8 +196,7 @@ func (ep *egressProxy) syncEgressConfigs(cfgs *egressservices.Configs, status *e
|
|||||||
if !local.IsValid() {
|
if !local.IsValid() {
|
||||||
return nil, fmt.Errorf("no valid local IP: %v", local)
|
return nil, fmt.Errorf("no valid local IP: %v", local)
|
||||||
}
|
}
|
||||||
// TODO(irbekrm): only create the SNAT rule if it does not already exist.
|
if err := ep.nfr.EnsureSNATForDst(local, t); err != nil {
|
||||||
if err := ep.nfr.AddSNATRuleForDst(local, t); err != nil {
|
|
||||||
return nil, fmt.Errorf("error setting up SNAT rule: %w", err)
|
return nil, fmt.Errorf("error setting up SNAT rule: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -9,6 +9,7 @@
|
|||||||
"bytes"
|
"bytes"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
@ -371,9 +372,42 @@ func (i *iptablesRunner) AddDNATRule(origDst, dst netip.Addr) error {
|
|||||||
return table.Insert("nat", "PREROUTING", 1, "--destination", origDst.String(), "-j", "DNAT", "--to-destination", dst.String())
|
return table.Insert("nat", "PREROUTING", 1, "--destination", origDst.String(), "-j", "DNAT", "--to-destination", dst.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *iptablesRunner) AddSNATRuleForDst(src, dst netip.Addr) error {
|
// EnsureSNATForDst sets up firewall to ensure that all traffic aimed for dst, has its source ip set to src:
|
||||||
|
// - creates a SNAT rule if not already present
|
||||||
|
// - ensures that any no longer valid SNAT rules for the same dst are removed
|
||||||
|
func (i *iptablesRunner) EnsureSNATForDst(src, dst netip.Addr) error {
|
||||||
table := i.getIPTByAddr(dst)
|
table := i.getIPTByAddr(dst)
|
||||||
return table.Insert("nat", "POSTROUTING", 1, "--destination", dst.String(), "-j", "SNAT", "--to-source", src.String())
|
rules, err := table.List("nat", "POSTROUTING")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error listing rules: %v", err)
|
||||||
|
}
|
||||||
|
// iptables accept either address or a CIDR value for the --destination flag, but converts an address to /32
|
||||||
|
// CIDR. Explicitly passing a /32 CIDR made it possible to test this rule.
|
||||||
|
dstPrefix, err := dst.Prefix(32)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error calculating prefix of dst %v: %v", dst, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// wantsArgsPrefix is the prefix of the SNAT rule for the provided destination.
|
||||||
|
// We should only have one POSTROUTING rule with this prefix.
|
||||||
|
wantsArgsPrefix := fmt.Sprintf("-d %s -j SNAT --to-source", dstPrefix.String())
|
||||||
|
// wantsArgs is the actual SNAT rule that we want.
|
||||||
|
wantsArgs := fmt.Sprintf("%s %s", wantsArgsPrefix, src.String())
|
||||||
|
for _, r := range rules {
|
||||||
|
args := argsFromPostRoutingRule(r)
|
||||||
|
if strings.HasPrefix(args, wantsArgsPrefix) {
|
||||||
|
if strings.HasPrefix(args, wantsArgs) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// SNAT rule matching the destination, but for a different source - delete.
|
||||||
|
if err := table.Delete("nat", "POSTROUTING", strings.Split(args, " ")...); err != nil {
|
||||||
|
// If we failed to delete don't crash the node- the proxy should still be functioning.
|
||||||
|
log.Printf("[unexpected] error deleting rule %s: %v, please report it.", r, err)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return table.Insert("nat", "POSTROUTING", 1, "-d", dstPrefix.String(), "-j", "SNAT", "--to-source", src.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *iptablesRunner) DNATNonTailscaleTraffic(tun string, dst netip.Addr) error {
|
func (i *iptablesRunner) DNATNonTailscaleTraffic(tun string, dst netip.Addr) error {
|
||||||
@ -731,3 +765,10 @@ func clearRules(proto iptables.Protocol, logf logger.Logf) error {
|
|||||||
|
|
||||||
return multierr.New(errs...)
|
return multierr.New(errs...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// argsFromPostRoutingRule accepts a rule as returned by iptables.List and, if it is a rule from POSTROUTING chain,
|
||||||
|
// returns the args part, else returns the original rule.
|
||||||
|
func argsFromPostRoutingRule(r string) string {
|
||||||
|
args, _ := strings.CutPrefix(r, "-A POSTROUTING ")
|
||||||
|
return args
|
||||||
|
}
|
||||||
|
@ -289,3 +289,77 @@ func TestAddAndDelSNATRule(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEnsureSNATForDst_ipt(t *testing.T) {
|
||||||
|
ip1, ip2, ip3 := netip.MustParseAddr("100.99.99.99"), netip.MustParseAddr("100.88.88.88"), netip.MustParseAddr("100.77.77.77")
|
||||||
|
iptr := NewFakeIPTablesRunner()
|
||||||
|
|
||||||
|
// 1. A new rule gets added
|
||||||
|
mustCreateSNATRule_ipt(t, iptr, ip1, ip2)
|
||||||
|
checkSNATRule_ipt(t, iptr, ip1, ip2)
|
||||||
|
checkSNATRuleCount(t, iptr, ip1, 1)
|
||||||
|
|
||||||
|
// 2. Another call to EnsureSNATForDst with the same src and dst does not result in another rule being added.
|
||||||
|
mustCreateSNATRule_ipt(t, iptr, ip1, ip2)
|
||||||
|
checkSNATRule_ipt(t, iptr, ip1, ip2)
|
||||||
|
checkSNATRuleCount(t, iptr, ip1, 1) // still just 1 rule
|
||||||
|
|
||||||
|
// 3. Another call to EnsureSNATForDst with a different src and the same dst results in the earlier rule being
|
||||||
|
// deleted.
|
||||||
|
mustCreateSNATRule_ipt(t, iptr, ip3, ip2)
|
||||||
|
checkSNATRule_ipt(t, iptr, ip3, ip2)
|
||||||
|
checkSNATRuleCount(t, iptr, ip1, 1) // still just 1 rule
|
||||||
|
|
||||||
|
// 4. Another call to EnsureSNATForDst with a different dst should not get the earlier rule deleted.
|
||||||
|
mustCreateSNATRule_ipt(t, iptr, ip3, ip1)
|
||||||
|
checkSNATRule_ipt(t, iptr, ip3, ip1)
|
||||||
|
checkSNATRuleCount(t, iptr, ip1, 2) // now 2 rules
|
||||||
|
|
||||||
|
// 5. A call to EnsureSNATForDst with a match dst and a match port should not get deleted by EnsureSNATForDst for the same dst.
|
||||||
|
args := []string{"--destination", ip1.String(), "-j", "SNAT", "--to-source", "10.0.0.1"}
|
||||||
|
if err := iptr.getIPTByAddr(ip1).Insert("nat", "POSTROUTING", 1, args...); err != nil {
|
||||||
|
t.Fatalf("error adding SNAT rule: %v", err)
|
||||||
|
}
|
||||||
|
exists, err := iptr.getIPTByAddr(ip1).Exists("nat", "POSTROUTING", args...)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error checking if rule exists: %v", err)
|
||||||
|
}
|
||||||
|
if !exists {
|
||||||
|
t.Fatalf("SNAT rule for destination and port unexpectedly deleted")
|
||||||
|
}
|
||||||
|
mustCreateSNATRule_ipt(t, iptr, ip3, ip1)
|
||||||
|
checkSNATRuleCount(t, iptr, ip1, 3) // now 3 rules
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustCreateSNATRule_ipt(t *testing.T, iptr *iptablesRunner, src, dst netip.Addr) {
|
||||||
|
t.Helper()
|
||||||
|
if err := iptr.EnsureSNATForDst(src, dst); err != nil {
|
||||||
|
t.Fatalf("error ensuring SNAT rule: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkSNATRule_ipt(t *testing.T, iptr *iptablesRunner, src, dst netip.Addr) {
|
||||||
|
t.Helper()
|
||||||
|
dstPrefix, err := dst.Prefix(32)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error converting addr to prefix: %v", err)
|
||||||
|
}
|
||||||
|
exists, err := iptr.getIPTByAddr(src).Exists("nat", "POSTROUTING", "-d", dstPrefix.String(), "-j", "SNAT", "--to-source", src.String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error checking if rule exists: %v", err)
|
||||||
|
}
|
||||||
|
if !exists {
|
||||||
|
t.Fatalf("SNAT rule for src %s dst %s should exist, but it does not", src, dst)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkSNATRuleCount(t *testing.T, iptr *iptablesRunner, ip netip.Addr, wantsRules int) {
|
||||||
|
t.Helper()
|
||||||
|
rules, err := iptr.getIPTByAddr(ip).List("nat", "POSTROUTING")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error listing rules: %v", err)
|
||||||
|
}
|
||||||
|
if len(rules) != wantsRules {
|
||||||
|
t.Fatalf("wants %d rules, got %d", wantsRules, len(rules))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -27,32 +27,32 @@ func Test_nftablesRunner_EnsurePortMapRuleForSvc(t *testing.T) {
|
|||||||
runner.EnsurePortMapRuleForSvc("foo", "tailscale0", ipv4, pmTCP)
|
runner.EnsurePortMapRuleForSvc("foo", "tailscale0", ipv4, pmTCP)
|
||||||
svcChains(t, 1, conn)
|
svcChains(t, 1, conn)
|
||||||
chainRuleCount(t, "foo", 1, conn, nftables.TableFamilyIPv4)
|
chainRuleCount(t, "foo", 1, conn, nftables.TableFamilyIPv4)
|
||||||
chainRule(t, "foo", ipv4, pmTCP, runner, nftables.TableFamilyIPv4)
|
checkPortMapRule(t, "foo", ipv4, pmTCP, runner, nftables.TableFamilyIPv4)
|
||||||
|
|
||||||
// Create another rule for service 'foo' to forward TCP traffic to the
|
// Create another rule for service 'foo' to forward TCP traffic to the
|
||||||
// same IPv4 endpoint, but to a different port.
|
// same IPv4 endpoint, but to a different port.
|
||||||
runner.EnsurePortMapRuleForSvc("foo", "tailscale0", ipv4, pmTCP1)
|
runner.EnsurePortMapRuleForSvc("foo", "tailscale0", ipv4, pmTCP1)
|
||||||
svcChains(t, 1, conn)
|
svcChains(t, 1, conn)
|
||||||
chainRuleCount(t, "foo", 2, conn, nftables.TableFamilyIPv4)
|
chainRuleCount(t, "foo", 2, conn, nftables.TableFamilyIPv4)
|
||||||
chainRule(t, "foo", ipv4, pmTCP1, runner, nftables.TableFamilyIPv4)
|
checkPortMapRule(t, "foo", ipv4, pmTCP1, runner, nftables.TableFamilyIPv4)
|
||||||
|
|
||||||
// Create a rule for service 'foo' to forward TCP traffic to an IPv6 endpoint
|
// Create a rule for service 'foo' to forward TCP traffic to an IPv6 endpoint
|
||||||
runner.EnsurePortMapRuleForSvc("foo", "tailscale0", ipv6, pmTCP)
|
runner.EnsurePortMapRuleForSvc("foo", "tailscale0", ipv6, pmTCP)
|
||||||
svcChains(t, 2, conn)
|
svcChains(t, 2, conn)
|
||||||
chainRuleCount(t, "foo", 1, conn, nftables.TableFamilyIPv6)
|
chainRuleCount(t, "foo", 1, conn, nftables.TableFamilyIPv6)
|
||||||
chainRule(t, "foo", ipv6, pmTCP, runner, nftables.TableFamilyIPv6)
|
checkPortMapRule(t, "foo", ipv6, pmTCP, runner, nftables.TableFamilyIPv6)
|
||||||
|
|
||||||
// Create a rule for service 'bar' to forward TCP traffic to IPv4 endpoint
|
// Create a rule for service 'bar' to forward TCP traffic to IPv4 endpoint
|
||||||
runner.EnsurePortMapRuleForSvc("bar", "tailscale0", ipv4, pmTCP)
|
runner.EnsurePortMapRuleForSvc("bar", "tailscale0", ipv4, pmTCP)
|
||||||
svcChains(t, 3, conn)
|
svcChains(t, 3, conn)
|
||||||
chainRuleCount(t, "bar", 1, conn, nftables.TableFamilyIPv4)
|
chainRuleCount(t, "bar", 1, conn, nftables.TableFamilyIPv4)
|
||||||
chainRule(t, "bar", ipv4, pmTCP, runner, nftables.TableFamilyIPv4)
|
checkPortMapRule(t, "bar", ipv4, pmTCP, runner, nftables.TableFamilyIPv4)
|
||||||
|
|
||||||
// Create a rule for service 'bar' to forward TCP traffic to an IPv6 endpoint
|
// Create a rule for service 'bar' to forward TCP traffic to an IPv6 endpoint
|
||||||
runner.EnsurePortMapRuleForSvc("bar", "tailscale0", ipv6, pmTCP)
|
runner.EnsurePortMapRuleForSvc("bar", "tailscale0", ipv6, pmTCP)
|
||||||
svcChains(t, 4, conn)
|
svcChains(t, 4, conn)
|
||||||
chainRuleCount(t, "bar", 1, conn, nftables.TableFamilyIPv6)
|
chainRuleCount(t, "bar", 1, conn, nftables.TableFamilyIPv6)
|
||||||
chainRule(t, "bar", ipv6, pmTCP, runner, nftables.TableFamilyIPv6)
|
checkPortMapRule(t, "bar", ipv6, pmTCP, runner, nftables.TableFamilyIPv6)
|
||||||
|
|
||||||
// Delete service bar
|
// Delete service bar
|
||||||
runner.DeleteSvc("bar", "tailscale0", []netip.Addr{ipv4, ipv6}, []PortMap{pmTCP})
|
runner.DeleteSvc("bar", "tailscale0", []netip.Addr{ipv4, ipv6}, []PortMap{pmTCP})
|
||||||
@ -95,36 +95,26 @@ func svcChains(t *testing.T, wantCount int, conn *nftables.Conn) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// chainRuleCount returns number of rules in a chain identified by service name and IP family.
|
// chainRuleCount verifies that the named chain in the given table contains the provided number of rules.
|
||||||
func chainRuleCount(t *testing.T, svc string, count int, conn *nftables.Conn, fam nftables.TableFamily) {
|
func chainRuleCount(t *testing.T, name string, numOfRules int, conn *nftables.Conn, fam nftables.TableFamily) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
chains, err := conn.ListChainsOfTableFamily(fam)
|
chains, err := conn.ListChainsOfTableFamily(fam)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("error listing chains: %v", err)
|
t.Fatalf("error listing chains: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
found := false
|
|
||||||
for _, ch := range chains {
|
for _, ch := range chains {
|
||||||
if ch.Name == svc {
|
if ch.Name == name {
|
||||||
found = true
|
checkChainRules(t, conn, ch, numOfRules)
|
||||||
rules, err := conn.GetRules(ch.Table, ch)
|
return
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("error getting rules: %v", err)
|
|
||||||
}
|
|
||||||
if len(rules) != count {
|
|
||||||
t.Fatalf("unexpected number of rules, wants %d got %d", count, len(rules))
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !found {
|
t.Fatalf("chain %s does not exist", name)
|
||||||
t.Fatalf("chain for service %s does not exist", svc)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// chainRule verifies that rule for the provided target IP and PortMap exists in
|
// checkPortMapRule verifies that rule for the provided target IP and PortMap exists in a chain identified by service
|
||||||
// a chain identified by service name and IP family.
|
// name and IP family.
|
||||||
func chainRule(t *testing.T, svc string, targetIP netip.Addr, pm PortMap, runner *nftablesRunner, fam nftables.TableFamily) {
|
func checkPortMapRule(t *testing.T, svc string, targetIP netip.Addr, pm PortMap, runner *nftablesRunner, fam nftables.TableFamily) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
chains, err := runner.conn.ListChainsOfTableFamily(fam)
|
chains, err := runner.conn.ListChainsOfTableFamily(fam)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -146,11 +136,17 @@ func chainRule(t *testing.T, svc string, targetIP netip.Addr, pm PortMap, runner
|
|||||||
t.Fatalf("error converting protocol: %v", err)
|
t.Fatalf("error converting protocol: %v", err)
|
||||||
}
|
}
|
||||||
wantsRule := portMapRule(chain.Table, chain, "tailscale0", targetIP, pm.MatchPort, pm.TargetPort, p, meta)
|
wantsRule := portMapRule(chain.Table, chain, "tailscale0", targetIP, pm.MatchPort, pm.TargetPort, p, meta)
|
||||||
gotRule, err := findRule(runner.conn, wantsRule)
|
checkRule(t, wantsRule, runner.conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkRule checks that the provided rules exists.
|
||||||
|
func checkRule(t *testing.T, rule *nftables.Rule, conn *nftables.Conn) {
|
||||||
|
t.Helper()
|
||||||
|
gotRule, err := findRule(conn, rule)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("error looking up rule: %v", err)
|
t.Fatalf("error looking up rule: %v", err)
|
||||||
}
|
}
|
||||||
if gotRule == nil {
|
if gotRule == nil {
|
||||||
t.Fatalf("rule not found")
|
t.Fatal("rule not found")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -193,7 +193,7 @@ func (n *nftablesRunner) DNATNonTailscaleTraffic(tunname string, dst netip.Addr)
|
|||||||
return n.conn.Flush()
|
return n.conn.Flush()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *nftablesRunner) AddSNATRuleForDst(src, dst netip.Addr) error {
|
func (n *nftablesRunner) EnsureSNATForDst(src, dst netip.Addr) error {
|
||||||
polAccept := nftables.ChainPolicyAccept
|
polAccept := nftables.ChainPolicyAccept
|
||||||
table, err := n.getNFTByAddr(dst)
|
table, err := n.getNFTByAddr(dst)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -216,44 +216,26 @@ func (n *nftablesRunner) AddSNATRuleForDst(src, dst netip.Addr) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error ensuring postrouting chain: %w", err)
|
return fmt.Errorf("error ensuring postrouting chain: %w", err)
|
||||||
}
|
}
|
||||||
var daddrOffset, fam, daddrLen uint32
|
|
||||||
if dst.Is4() {
|
|
||||||
daddrOffset = 16
|
|
||||||
daddrLen = 4
|
|
||||||
fam = unix.NFPROTO_IPV4
|
|
||||||
} else {
|
|
||||||
daddrOffset = 24
|
|
||||||
daddrLen = 16
|
|
||||||
fam = unix.NFPROTO_IPV6
|
|
||||||
}
|
|
||||||
|
|
||||||
snatRule := &nftables.Rule{
|
rules, err := n.conn.GetRules(nat, postRoutingCh)
|
||||||
Table: nat,
|
if err != nil {
|
||||||
Chain: postRoutingCh,
|
return fmt.Errorf("error listing rules: %w", err)
|
||||||
Exprs: []expr.Any{
|
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 1,
|
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
|
||||||
Offset: daddrOffset,
|
|
||||||
Len: daddrLen,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: dst.AsSlice(),
|
|
||||||
},
|
|
||||||
&expr.Immediate{
|
|
||||||
Register: 1,
|
|
||||||
Data: src.AsSlice(),
|
|
||||||
},
|
|
||||||
&expr.NAT{
|
|
||||||
Type: expr.NATTypeSourceNAT,
|
|
||||||
Family: fam,
|
|
||||||
RegAddrMin: 1,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
n.conn.AddRule(snatRule)
|
snatRulePrefixMatch := fmt.Sprintf("dst:%s,src:", dst.String())
|
||||||
|
snatRuleFullMatch := fmt.Sprintf("%s%s", snatRulePrefixMatch, src.String())
|
||||||
|
for _, rule := range rules {
|
||||||
|
current := string(rule.UserData)
|
||||||
|
if strings.HasPrefix(string(rule.UserData), snatRulePrefixMatch) {
|
||||||
|
if strings.EqualFold(current, snatRuleFullMatch) {
|
||||||
|
return nil // already exists, do nothing
|
||||||
|
}
|
||||||
|
if err := n.conn.DelRule(rule); err != nil {
|
||||||
|
return fmt.Errorf("error deleting SNAT rule: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rule := snatRule(nat, postRoutingCh, src, dst, []byte(snatRuleFullMatch))
|
||||||
|
n.conn.AddRule(rule)
|
||||||
return n.conn.Flush()
|
return n.conn.Flush()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -557,11 +539,12 @@ type NetfilterRunner interface {
|
|||||||
// in the Kubernetes ingress proxies.
|
// in the Kubernetes ingress proxies.
|
||||||
DNATWithLoadBalancer(origDst netip.Addr, dsts []netip.Addr) error
|
DNATWithLoadBalancer(origDst netip.Addr, dsts []netip.Addr) error
|
||||||
|
|
||||||
// AddSNATRuleForDst adds a rule to the nat/POSTROUTING chain to SNAT
|
// EnsureSNATForDst sets up firewall to mask the source for traffic destined for dst to src:
|
||||||
// traffic destined for dst to src.
|
// - creates a SNAT rule if it doesn't already exist
|
||||||
|
// - deletes any pre-existing rules matching the destination
|
||||||
// This is used to forward traffic destined for the local machine over
|
// This is used to forward traffic destined for the local machine over
|
||||||
// the Tailscale interface, as used in the Kubernetes egress proxies.
|
// the Tailscale interface, as used in the Kubernetes egress proxies.
|
||||||
AddSNATRuleForDst(src, dst netip.Addr) error
|
EnsureSNATForDst(src, dst netip.Addr) error
|
||||||
|
|
||||||
// DNATNonTailscaleTraffic adds a rule to the nat/PREROUTING chain to DNAT
|
// DNATNonTailscaleTraffic adds a rule to the nat/PREROUTING chain to DNAT
|
||||||
// all traffic inbound from any interface except exemptInterface to dst.
|
// all traffic inbound from any interface except exemptInterface to dst.
|
||||||
@ -2028,3 +2011,45 @@ func NfTablesCleanUp(logf logger.Logf) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func snatRule(t *nftables.Table, ch *nftables.Chain, src, dst netip.Addr, meta []byte) *nftables.Rule {
|
||||||
|
var daddrOffset, fam, daddrLen uint32
|
||||||
|
if dst.Is4() {
|
||||||
|
daddrOffset = 16
|
||||||
|
daddrLen = 4
|
||||||
|
fam = unix.NFPROTO_IPV4
|
||||||
|
} else {
|
||||||
|
daddrOffset = 24
|
||||||
|
daddrLen = 16
|
||||||
|
fam = unix.NFPROTO_IPV6
|
||||||
|
}
|
||||||
|
|
||||||
|
return &nftables.Rule{
|
||||||
|
Table: t,
|
||||||
|
Chain: ch,
|
||||||
|
Exprs: []expr.Any{
|
||||||
|
&expr.Payload{
|
||||||
|
DestRegister: 1,
|
||||||
|
Base: expr.PayloadBaseNetworkHeader,
|
||||||
|
Offset: daddrOffset,
|
||||||
|
Len: daddrLen,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: dst.AsSlice(),
|
||||||
|
},
|
||||||
|
&expr.Immediate{
|
||||||
|
Register: 1,
|
||||||
|
Data: src.AsSlice(),
|
||||||
|
},
|
||||||
|
&expr.NAT{
|
||||||
|
Type: expr.NATTypeSourceNAT,
|
||||||
|
Family: fam,
|
||||||
|
RegAddrMin: 1,
|
||||||
|
RegAddrMax: 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
UserData: meta,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -954,6 +954,37 @@ func TestPickFirewallModeFromInstalledRules(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This test creates a temporary network namespace for the nftables rules being
|
||||||
|
// set up, so it needs to run in a privileged mode. Locally it needs to be run
|
||||||
|
// by root, else it will be silently skipped. In CI it runs in a privileged
|
||||||
|
// container.
|
||||||
|
func TestEnsureSNATForDst_nftables(t *testing.T) {
|
||||||
|
conn := newSysConn(t)
|
||||||
|
runner := newFakeNftablesRunnerWithConn(t, conn, true)
|
||||||
|
ip1, ip2, ip3 := netip.MustParseAddr("100.99.99.99"), netip.MustParseAddr("100.88.88.88"), netip.MustParseAddr("100.77.77.77")
|
||||||
|
|
||||||
|
// 1. A new rule gets added
|
||||||
|
mustCreateSNATRule_nft(t, runner, ip1, ip2)
|
||||||
|
chainRuleCount(t, "POSTROUTING", 1, conn, nftables.TableFamilyIPv4)
|
||||||
|
checkSNATRule_nft(t, runner, runner.nft4.Proto, ip1, ip2)
|
||||||
|
|
||||||
|
// 2. Another call to EnsureSNATForDst with the same src and dst does not result in another rule being added.
|
||||||
|
mustCreateSNATRule_nft(t, runner, ip1, ip2)
|
||||||
|
chainRuleCount(t, "POSTROUTING", 1, conn, nftables.TableFamilyIPv4) // still just one rule
|
||||||
|
checkSNATRule_nft(t, runner, runner.nft4.Proto, ip1, ip2)
|
||||||
|
|
||||||
|
// 3. Another call to EnsureSNATForDst with a different src and the same dst results in the earlier rule being
|
||||||
|
// deleted.
|
||||||
|
mustCreateSNATRule_nft(t, runner, ip3, ip2)
|
||||||
|
chainRuleCount(t, "POSTROUTING", 1, conn, nftables.TableFamilyIPv4) // still just one rule
|
||||||
|
checkSNATRule_nft(t, runner, runner.nft4.Proto, ip3, ip2)
|
||||||
|
|
||||||
|
// 4. Another call to EnsureSNATForDst with a different dst should not get the earlier rule deleted.
|
||||||
|
mustCreateSNATRule_nft(t, runner, ip3, ip1)
|
||||||
|
chainRuleCount(t, "POSTROUTING", 2, conn, nftables.TableFamilyIPv4) // now two rules
|
||||||
|
checkSNATRule_nft(t, runner, runner.nft4.Proto, ip3, ip1)
|
||||||
|
}
|
||||||
|
|
||||||
func newFakeNftablesRunnerWithConn(t *testing.T, conn *nftables.Conn, hasIPv6 bool) *nftablesRunner {
|
func newFakeNftablesRunnerWithConn(t *testing.T, conn *nftables.Conn, hasIPv6 bool) *nftablesRunner {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
if !hasIPv6 {
|
if !hasIPv6 {
|
||||||
@ -964,3 +995,32 @@ func newFakeNftablesRunnerWithConn(t *testing.T, conn *nftables.Conn, hasIPv6 bo
|
|||||||
}
|
}
|
||||||
return newNfTablesRunnerWithConn(t.Logf, conn)
|
return newNfTablesRunnerWithConn(t.Logf, conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func mustCreateSNATRule_nft(t *testing.T, runner *nftablesRunner, src, dst netip.Addr) {
|
||||||
|
t.Helper()
|
||||||
|
if err := runner.EnsureSNATForDst(src, dst); err != nil {
|
||||||
|
t.Fatalf("error ensuring SNAT rule: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkSNATRule_nft verifies that a SNAT rule for the given destination and source exists.
|
||||||
|
func checkSNATRule_nft(t *testing.T, runner *nftablesRunner, fam nftables.TableFamily, src, dst netip.Addr) {
|
||||||
|
t.Helper()
|
||||||
|
chains, err := runner.conn.ListChainsOfTableFamily(fam)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error listing chains: %v", err)
|
||||||
|
}
|
||||||
|
var chain *nftables.Chain
|
||||||
|
for _, ch := range chains {
|
||||||
|
if ch.Name == "POSTROUTING" {
|
||||||
|
chain = ch
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if chain == nil {
|
||||||
|
t.Fatal("POSTROUTING chain does not exist")
|
||||||
|
}
|
||||||
|
meta := []byte(fmt.Sprintf("dst:%s,src:%s", dst.String(), src.String()))
|
||||||
|
wantsRule := snatRule(chain.Table, chain, src, dst, meta)
|
||||||
|
checkRule(t, wantsRule, runner.conn)
|
||||||
|
}
|
||||||
|
@ -530,7 +530,7 @@ func (n *fakeIPTablesRunner) DNATWithLoadBalancer(netip.Addr, []netip.Addr) erro
|
|||||||
return errors.New("not implemented")
|
return errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *fakeIPTablesRunner) AddSNATRuleForDst(src, dst netip.Addr) error {
|
func (n *fakeIPTablesRunner) EnsureSNATForDst(src, dst netip.Addr) error {
|
||||||
return errors.New("not implemented")
|
return errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user