| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | // Copyright (c) Tailscale Inc & AUTHORS | 
					
						
							|  |  |  | // SPDX-License-Identifier: BSD-3-Clause | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | //go:build linux | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | package linuxfw | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import ( | 
					
						
							|  |  |  | 	"encoding/binary" | 
					
						
							|  |  |  | 	"encoding/hex" | 
					
						
							|  |  |  | 	"errors" | 
					
						
							|  |  |  | 	"fmt" | 
					
						
							|  |  |  | 	"net" | 
					
						
							|  |  |  | 	"net/netip" | 
					
						
							|  |  |  | 	"reflect" | 
					
						
							| 
									
										
										
										
											2023-08-18 18:41:28 +00:00
										 |  |  | 	"strings" | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | 	"github.com/google/nftables" | 
					
						
							|  |  |  | 	"github.com/google/nftables/expr" | 
					
						
							| 
									
										
										
										
											2023-10-11 07:26:40 -07:00
										 |  |  | 	"golang.org/x/sys/unix" | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | 	"tailscale.com/net/tsaddr" | 
					
						
							|  |  |  | 	"tailscale.com/types/logger" | 
					
						
							| 
									
										
										
										
											2023-10-17 19:51:40 +00:00
										 |  |  | 	"tailscale.com/types/ptr" | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | const ( | 
					
						
							|  |  |  | 	chainNameForward     = "ts-forward" | 
					
						
							|  |  |  | 	chainNameInput       = "ts-input" | 
					
						
							|  |  |  | 	chainNamePostrouting = "ts-postrouting" | 
					
						
							|  |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-18 18:41:28 +00:00
										 |  |  | // chainTypeRegular is an nftables chain that does not apply to a hook. | 
					
						
							|  |  |  | const chainTypeRegular = "" | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | type chainInfo struct { | 
					
						
							|  |  |  | 	table         *nftables.Table | 
					
						
							|  |  |  | 	name          string | 
					
						
							|  |  |  | 	chainType     nftables.ChainType | 
					
						
							|  |  |  | 	chainHook     *nftables.ChainHook | 
					
						
							|  |  |  | 	chainPriority *nftables.ChainPriority | 
					
						
							| 
									
										
										
										
											2023-08-18 18:41:28 +00:00
										 |  |  | 	chainPolicy   *nftables.ChainPolicy | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-05-14 08:51:53 +01:00
										 |  |  | // nftable contains nat and filter tables for the given IP family (Proto). | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | type nftable struct { | 
					
						
							| 
									
										
										
										
											2024-05-14 08:51:53 +01:00
										 |  |  | 	Proto  nftables.TableFamily // IPv4 or IPv6 | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | 	Filter *nftables.Table | 
					
						
							|  |  |  | 	Nat    *nftables.Table | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-18 18:41:28 +00:00
										 |  |  | // nftablesRunner implements a netfilterRunner using the netlink based nftables | 
					
						
							|  |  |  | // library. As nftables allows for arbitrary tables and chains, there is a need | 
					
						
							|  |  |  | // to follow conventions in order to integrate well with a surrounding | 
					
						
							|  |  |  | // ecosystem. The rules installed by nftablesRunner have the following | 
					
						
							|  |  |  | // properties: | 
					
						
							|  |  |  | //   - Install rules that intend to take precedence over rules installed by | 
					
						
							|  |  |  | //     other software. Tailscale provides packet filtering for tailnet traffic | 
					
						
							|  |  |  | //     inside the daemon based on the tailnet ACL rules. | 
					
						
							|  |  |  | //   - As nftables "accept" is not final, rules from high priority tables (low | 
					
						
							|  |  |  | //     numbers) will fall through to lower priority tables (high numbers). In | 
					
						
							|  |  |  | //     order to effectively be 'final', we install "jump" rules into conventional | 
					
						
							|  |  |  | //     tables and chains that will reach an accept verdict inside those tables. | 
					
						
							|  |  |  | //   - The table and chain conventions followed here are those used by | 
					
						
							|  |  |  | //     `iptables-nft` and `ufw`, so that those tools co-exist and do not | 
					
						
							|  |  |  | //     negatively affect Tailscale function. | 
					
						
							| 
									
										
										
										
											2024-04-02 19:31:33 +01:00
										 |  |  | //   - Be mindful that 1) all chains attached to a given hook (i.e the forward hook) | 
					
						
							|  |  |  | //     will be processed in priority order till either a rule in one of the chains issues a drop verdict | 
					
						
							|  |  |  | //     or there are no more chains for that hook | 
					
						
							|  |  |  | //     2) processing of individual rules within a chain will stop once one of them issues a final verdict (accept, drop). | 
					
						
							|  |  |  | //     https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | type nftablesRunner struct { | 
					
						
							|  |  |  | 	conn *nftables.Conn | 
					
						
							| 
									
										
										
										
											2024-06-25 14:13:49 +01:00
										 |  |  | 	nft4 *nftable // IPv4 tables, never nil | 
					
						
							|  |  |  | 	nft6 *nftable // IPv6 tables or nil if the system does not support IPv6 | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-05-14 08:51:53 +01:00
										 |  |  | 	v6Available bool // whether the host supports IPv6 | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-10-11 07:26:40 -07:00
										 |  |  | func (n *nftablesRunner) ensurePreroutingChain(dst netip.Addr) (*nftables.Table, *nftables.Chain, error) { | 
					
						
							|  |  |  | 	polAccept := nftables.ChainPolicyAccept | 
					
						
							| 
									
										
										
										
											2024-06-25 14:13:49 +01:00
										 |  |  | 	table, err := n.getNFTByAddr(dst) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return nil, nil, fmt.Errorf("error setting up nftables for IP family of %v: %w", dst, err) | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2023-10-11 07:26:40 -07:00
										 |  |  | 	nat, err := createTableIfNotExist(n.conn, table.Proto, "nat") | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return nil, nil, fmt.Errorf("error ensuring nat table: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// ensure prerouting chain exists | 
					
						
							|  |  |  | 	preroutingCh, err := getOrCreateChain(n.conn, chainInfo{ | 
					
						
							|  |  |  | 		table:         nat, | 
					
						
							|  |  |  | 		name:          "PREROUTING", | 
					
						
							|  |  |  | 		chainType:     nftables.ChainTypeNAT, | 
					
						
							|  |  |  | 		chainHook:     nftables.ChainHookPrerouting, | 
					
						
							|  |  |  | 		chainPriority: nftables.ChainPriorityNATDest, | 
					
						
							|  |  |  | 		chainPolicy:   &polAccept, | 
					
						
							|  |  |  | 	}) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return nil, nil, fmt.Errorf("error ensuring prerouting chain: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return nat, preroutingCh, nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (n *nftablesRunner) AddDNATRule(origDst netip.Addr, dst netip.Addr) error { | 
					
						
							|  |  |  | 	nat, preroutingCh, err := n.ensurePreroutingChain(dst) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	var daddrOffset, fam, dadderLen uint32 | 
					
						
							|  |  |  | 	if origDst.Is4() { | 
					
						
							|  |  |  | 		daddrOffset = 16 | 
					
						
							|  |  |  | 		dadderLen = 4 | 
					
						
							|  |  |  | 		fam = unix.NFPROTO_IPV4 | 
					
						
							|  |  |  | 	} else { | 
					
						
							|  |  |  | 		daddrOffset = 24 | 
					
						
							|  |  |  | 		dadderLen = 16 | 
					
						
							|  |  |  | 		fam = unix.NFPROTO_IPV6 | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	dnatRule := &nftables.Rule{ | 
					
						
							|  |  |  | 		Table: nat, | 
					
						
							|  |  |  | 		Chain: preroutingCh, | 
					
						
							|  |  |  | 		Exprs: []expr.Any{ | 
					
						
							|  |  |  | 			&expr.Payload{ | 
					
						
							|  |  |  | 				DestRegister: 1, | 
					
						
							|  |  |  | 				Base:         expr.PayloadBaseNetworkHeader, | 
					
						
							|  |  |  | 				Offset:       daddrOffset, | 
					
						
							|  |  |  | 				Len:          dadderLen, | 
					
						
							|  |  |  | 			}, | 
					
						
							|  |  |  | 			&expr.Cmp{ | 
					
						
							|  |  |  | 				Op:       expr.CmpOpEq, | 
					
						
							|  |  |  | 				Register: 1, | 
					
						
							|  |  |  | 				Data:     origDst.AsSlice(), | 
					
						
							|  |  |  | 			}, | 
					
						
							|  |  |  | 			&expr.Immediate{ | 
					
						
							|  |  |  | 				Register: 1, | 
					
						
							|  |  |  | 				Data:     dst.AsSlice(), | 
					
						
							|  |  |  | 			}, | 
					
						
							|  |  |  | 			&expr.NAT{ | 
					
						
							|  |  |  | 				Type:       expr.NATTypeDestNAT, | 
					
						
							|  |  |  | 				Family:     fam, | 
					
						
							|  |  |  | 				RegAddrMin: 1, | 
					
						
							|  |  |  | 			}, | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	n.conn.InsertRule(dnatRule) | 
					
						
							|  |  |  | 	return n.conn.Flush() | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-23 17:30:00 +01:00
										 |  |  | // DNATWithLoadBalancer currently just forwards all traffic destined for origDst | 
					
						
							|  |  |  | // to the first IP address from the backend targets. | 
					
						
							|  |  |  | // TODO (irbekrm): instead of doing this load balance traffic evenly to all | 
					
						
							|  |  |  | // backend destinations. | 
					
						
							|  |  |  | // https://github.com/tailscale/tailscale/commit/d37f2f508509c6c35ad724fd75a27685b90b575b#diff-a3bcbcd1ca198799f4f768dc56fea913e1945a6b3ec9dbec89325a84a19a85e7R148-R232 | 
					
						
							|  |  |  | func (n *nftablesRunner) DNATWithLoadBalancer(origDst netip.Addr, dsts []netip.Addr) error { | 
					
						
							|  |  |  | 	return n.AddDNATRule(origDst, dsts[0]) | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-10-11 07:26:40 -07:00
										 |  |  | func (n *nftablesRunner) DNATNonTailscaleTraffic(tunname string, dst netip.Addr) error { | 
					
						
							|  |  |  | 	nat, preroutingCh, err := n.ensurePreroutingChain(dst) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	var famConst uint32 | 
					
						
							|  |  |  | 	if dst.Is4() { | 
					
						
							|  |  |  | 		famConst = unix.NFPROTO_IPV4 | 
					
						
							|  |  |  | 	} else { | 
					
						
							|  |  |  | 		famConst = unix.NFPROTO_IPV6 | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	dnatRule := &nftables.Rule{ | 
					
						
							|  |  |  | 		Table: nat, | 
					
						
							|  |  |  | 		Chain: preroutingCh, | 
					
						
							|  |  |  | 		Exprs: []expr.Any{ | 
					
						
							|  |  |  | 			&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, | 
					
						
							|  |  |  | 			&expr.Cmp{ | 
					
						
							|  |  |  | 				Op:       expr.CmpOpNeq, | 
					
						
							|  |  |  | 				Register: 1, | 
					
						
							|  |  |  | 				Data:     []byte(tunname), | 
					
						
							|  |  |  | 			}, | 
					
						
							|  |  |  | 			&expr.Immediate{ | 
					
						
							|  |  |  | 				Register: 1, | 
					
						
							|  |  |  | 				Data:     dst.AsSlice(), | 
					
						
							|  |  |  | 			}, | 
					
						
							|  |  |  | 			&expr.NAT{ | 
					
						
							|  |  |  | 				Type:       expr.NATTypeDestNAT, | 
					
						
							|  |  |  | 				Family:     famConst, | 
					
						
							|  |  |  | 				RegAddrMin: 1, | 
					
						
							|  |  |  | 			}, | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2024-02-29 16:53:43 +00:00
										 |  |  | 	n.conn.InsertRule(dnatRule) | 
					
						
							| 
									
										
										
										
											2023-10-11 07:26:40 -07:00
										 |  |  | 	return n.conn.Flush() | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (n *nftablesRunner) AddSNATRuleForDst(src, dst netip.Addr) error { | 
					
						
							|  |  |  | 	polAccept := nftables.ChainPolicyAccept | 
					
						
							| 
									
										
										
										
											2024-06-25 14:13:49 +01:00
										 |  |  | 	table, err := n.getNFTByAddr(dst) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("error setting up nftables for IP family of %v: %w", dst, err) | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2023-10-11 07:26:40 -07:00
										 |  |  | 	nat, err := createTableIfNotExist(n.conn, table.Proto, "nat") | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("error ensuring nat table exists: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// ensure postrouting chain exists | 
					
						
							|  |  |  | 	postRoutingCh, err := getOrCreateChain(n.conn, chainInfo{ | 
					
						
							|  |  |  | 		table:         nat, | 
					
						
							|  |  |  | 		name:          "POSTROUTING", | 
					
						
							|  |  |  | 		chainType:     nftables.ChainTypeNAT, | 
					
						
							|  |  |  | 		chainHook:     nftables.ChainHookPostrouting, | 
					
						
							|  |  |  | 		chainPriority: nftables.ChainPriorityNATSource, | 
					
						
							|  |  |  | 		chainPolicy:   &polAccept, | 
					
						
							|  |  |  | 	}) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("error ensuring postrouting chain: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	var daddrOffset, fam, daddrLen uint32 | 
					
						
							|  |  |  | 	if dst.Is4() { | 
					
						
							|  |  |  | 		daddrOffset = 16 | 
					
						
							|  |  |  | 		daddrLen = 4 | 
					
						
							|  |  |  | 		fam = unix.NFPROTO_IPV4 | 
					
						
							|  |  |  | 	} else { | 
					
						
							|  |  |  | 		daddrOffset = 24 | 
					
						
							|  |  |  | 		daddrLen = 16 | 
					
						
							|  |  |  | 		fam = unix.NFPROTO_IPV6 | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	snatRule := &nftables.Rule{ | 
					
						
							|  |  |  | 		Table: nat, | 
					
						
							|  |  |  | 		Chain: postRoutingCh, | 
					
						
							|  |  |  | 		Exprs: []expr.Any{ | 
					
						
							|  |  |  | 			&expr.Payload{ | 
					
						
							|  |  |  | 				DestRegister: 1, | 
					
						
							|  |  |  | 				Base:         expr.PayloadBaseNetworkHeader, | 
					
						
							|  |  |  | 				Offset:       daddrOffset, | 
					
						
							|  |  |  | 				Len:          daddrLen, | 
					
						
							|  |  |  | 			}, | 
					
						
							|  |  |  | 			&expr.Cmp{ | 
					
						
							|  |  |  | 				Op:       expr.CmpOpEq, | 
					
						
							|  |  |  | 				Register: 1, | 
					
						
							|  |  |  | 				Data:     dst.AsSlice(), | 
					
						
							|  |  |  | 			}, | 
					
						
							|  |  |  | 			&expr.Immediate{ | 
					
						
							|  |  |  | 				Register: 1, | 
					
						
							|  |  |  | 				Data:     src.AsSlice(), | 
					
						
							|  |  |  | 			}, | 
					
						
							|  |  |  | 			&expr.NAT{ | 
					
						
							|  |  |  | 				Type:       expr.NATTypeSourceNAT, | 
					
						
							|  |  |  | 				Family:     fam, | 
					
						
							|  |  |  | 				RegAddrMin: 1, | 
					
						
							|  |  |  | 			}, | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	n.conn.AddRule(snatRule) | 
					
						
							|  |  |  | 	return n.conn.Flush() | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-02 19:31:33 +01:00
										 |  |  | // ClampMSSToPMTU ensures that all packets with TCP flags (SYN, ACK, RST) set | 
					
						
							|  |  |  | // being forwarded via the given interface (tun) have MSS set to <MTU of the | 
					
						
							|  |  |  | // interface> - 40 (IP and TCP headers). This can be useful if this tailscale | 
					
						
							|  |  |  | // instance is expected to run as a forwarding proxy, forwarding packets from an | 
					
						
							|  |  |  | // endpoint with higher MTU in an environment where path MTU discovery is | 
					
						
							|  |  |  | // expected to not work (such as the proxies created by the Tailscale Kubernetes | 
					
						
							|  |  |  | // operator). ClamMSSToPMTU creates a new base-chain ts-clamp in the filter | 
					
						
							|  |  |  | // table with accept policy and priority -150. In practice, this means that for | 
					
						
							|  |  |  | // SYN packets the clamp rule in this chain will likely run first and accept the | 
					
						
							|  |  |  | // packet. This is fine because 1) nftables run ALL chains with the same hook | 
					
						
							|  |  |  | // type unless a rule in one of them drops the packet and 2) this chain does not | 
					
						
							|  |  |  | // have functionality to drop the packet- so in practice a matching clamp rule | 
					
						
							|  |  |  | // will always be followed by the custom tailscale filtering rules in the other | 
					
						
							|  |  |  | // chains attached to the filter hook (FORWARD, ts-forward). | 
					
						
							|  |  |  | // We do not want to place the clamping rule into FORWARD/ts-forward chains | 
					
						
							|  |  |  | // because wgengine populates those chains with rules that contain accept | 
					
						
							|  |  |  | // verdicts that would cause no further procesing within that chain. This | 
					
						
							|  |  |  | // functionality is currently invoked from outside wgengine (containerboot), so | 
					
						
							|  |  |  | // we don't want to race with wgengine for rule ordering within chains. | 
					
						
							| 
									
										
										
										
											2023-10-11 07:26:40 -07:00
										 |  |  | func (n *nftablesRunner) ClampMSSToPMTU(tun string, addr netip.Addr) error { | 
					
						
							|  |  |  | 	polAccept := nftables.ChainPolicyAccept | 
					
						
							| 
									
										
										
										
											2024-06-25 14:13:49 +01:00
										 |  |  | 	table, err := n.getNFTByAddr(addr) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("error setting up nftables for IP family of %v: %w", addr, err) | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2023-10-11 07:26:40 -07:00
										 |  |  | 	filterTable, err := createTableIfNotExist(n.conn, table.Proto, "filter") | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("error ensuring filter table: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-02 19:31:33 +01:00
										 |  |  | 	// ensure ts-clamp chain exists | 
					
						
							| 
									
										
										
										
											2023-10-11 07:26:40 -07:00
										 |  |  | 	fwChain, err := getOrCreateChain(n.conn, chainInfo{ | 
					
						
							|  |  |  | 		table:         filterTable, | 
					
						
							| 
									
										
										
										
											2024-04-02 19:31:33 +01:00
										 |  |  | 		name:          "ts-clamp", | 
					
						
							| 
									
										
										
										
											2023-10-11 07:26:40 -07:00
										 |  |  | 		chainType:     nftables.ChainTypeFilter, | 
					
						
							|  |  |  | 		chainHook:     nftables.ChainHookForward, | 
					
						
							| 
									
										
										
										
											2024-04-02 19:31:33 +01:00
										 |  |  | 		chainPriority: nftables.ChainPriorityMangle, | 
					
						
							| 
									
										
										
										
											2023-10-11 07:26:40 -07:00
										 |  |  | 		chainPolicy:   &polAccept, | 
					
						
							|  |  |  | 	}) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("error ensuring forward chain: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	clampRule := &nftables.Rule{ | 
					
						
							|  |  |  | 		Table: filterTable, | 
					
						
							|  |  |  | 		Chain: fwChain, | 
					
						
							|  |  |  | 		Exprs: []expr.Any{ | 
					
						
							|  |  |  | 			&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, | 
					
						
							|  |  |  | 			&expr.Cmp{ | 
					
						
							|  |  |  | 				Op:       expr.CmpOpEq, | 
					
						
							|  |  |  | 				Register: 1, | 
					
						
							|  |  |  | 				Data:     []byte(tun), | 
					
						
							|  |  |  | 			}, | 
					
						
							|  |  |  | 			&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1}, | 
					
						
							|  |  |  | 			&expr.Cmp{ | 
					
						
							|  |  |  | 				Op:       expr.CmpOpEq, | 
					
						
							|  |  |  | 				Register: 1, | 
					
						
							|  |  |  | 				Data:     []byte{unix.IPPROTO_TCP}, | 
					
						
							|  |  |  | 			}, | 
					
						
							|  |  |  | 			&expr.Payload{ | 
					
						
							|  |  |  | 				DestRegister: 1, | 
					
						
							|  |  |  | 				Base:         expr.PayloadBaseTransportHeader, | 
					
						
							|  |  |  | 				Offset:       13, | 
					
						
							|  |  |  | 				Len:          1, | 
					
						
							|  |  |  | 			}, | 
					
						
							|  |  |  | 			&expr.Bitwise{ | 
					
						
							|  |  |  | 				DestRegister:   1, | 
					
						
							|  |  |  | 				SourceRegister: 1, | 
					
						
							|  |  |  | 				Len:            1, | 
					
						
							|  |  |  | 				Mask:           []byte{0x02}, | 
					
						
							|  |  |  | 				Xor:            []byte{0x00}, | 
					
						
							|  |  |  | 			}, | 
					
						
							|  |  |  | 			&expr.Cmp{ | 
					
						
							| 
									
										
										
										
											2024-04-02 19:31:33 +01:00
										 |  |  | 				Op:       expr.CmpOpNeq, // match any packet with a TCP flag set (SYN, ACK, RST) | 
					
						
							| 
									
										
										
										
											2023-10-11 07:26:40 -07:00
										 |  |  | 				Register: 1, | 
					
						
							|  |  |  | 				Data:     []byte{0x00}, | 
					
						
							|  |  |  | 			}, | 
					
						
							|  |  |  | 			&expr.Rt{ | 
					
						
							|  |  |  | 				Register: 1, | 
					
						
							|  |  |  | 				Key:      expr.RtTCPMSS, | 
					
						
							|  |  |  | 			}, | 
					
						
							|  |  |  | 			&expr.Byteorder{ | 
					
						
							|  |  |  | 				DestRegister:   1, | 
					
						
							|  |  |  | 				SourceRegister: 1, | 
					
						
							|  |  |  | 				Op:             expr.ByteorderHton, | 
					
						
							|  |  |  | 				Len:            2, | 
					
						
							|  |  |  | 				Size:           2, | 
					
						
							|  |  |  | 			}, | 
					
						
							|  |  |  | 			&expr.Exthdr{ | 
					
						
							|  |  |  | 				SourceRegister: 1, | 
					
						
							|  |  |  | 				Type:           2, | 
					
						
							|  |  |  | 				Offset:         2, | 
					
						
							|  |  |  | 				Len:            2, | 
					
						
							|  |  |  | 				Op:             expr.ExthdrOpTcpopt, | 
					
						
							|  |  |  | 			}, | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	n.conn.AddRule(clampRule) | 
					
						
							|  |  |  | 	return n.conn.Flush() | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-10-17 19:51:40 +00:00
										 |  |  | // deleteTableIfExists deletes a nftables table via connection c if it exists | 
					
						
							|  |  |  | // within the given family. | 
					
						
							|  |  |  | func deleteTableIfExists(c *nftables.Conn, family nftables.TableFamily, name string) error { | 
					
						
							|  |  |  | 	t, err := getTableIfExists(c, family, name) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("get table: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	if t == nil { | 
					
						
							|  |  |  | 		// Table does not exist, so nothing to delete. | 
					
						
							|  |  |  | 		return nil | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	c.DelTable(t) | 
					
						
							|  |  |  | 	if err := c.Flush(); err != nil { | 
					
						
							|  |  |  | 		if t, err = getTableIfExists(c, family, name); t == nil && err == nil { | 
					
						
							|  |  |  | 			// Check if the table still exists. If it does not, then the error | 
					
						
							|  |  |  | 			// is due to the table not existing, so we can ignore it. Maybe a | 
					
						
							|  |  |  | 			// concurrent process deleted the table. | 
					
						
							|  |  |  | 			return nil | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		return fmt.Errorf("del table: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // getTableIfExists returns the table with the given name from the given family | 
					
						
							|  |  |  | // if it exists. If none match, it returns (nil, nil). | 
					
						
							|  |  |  | func getTableIfExists(c *nftables.Conn, family nftables.TableFamily, name string) (*nftables.Table, error) { | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | 	tables, err := c.ListTables() | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return nil, fmt.Errorf("get tables: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	for _, table := range tables { | 
					
						
							|  |  |  | 		if table.Name == name && table.Family == family { | 
					
						
							|  |  |  | 			return table, nil | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2023-10-17 19:51:40 +00:00
										 |  |  | 	return nil, nil | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-10-17 19:51:40 +00:00
										 |  |  | // createTableIfNotExist creates a nftables table via connection c if it does | 
					
						
							|  |  |  | // not exist within the given family. | 
					
						
							|  |  |  | func createTableIfNotExist(c *nftables.Conn, family nftables.TableFamily, name string) (*nftables.Table, error) { | 
					
						
							|  |  |  | 	if t, err := getTableIfExists(c, family, name); err != nil { | 
					
						
							|  |  |  | 		return nil, fmt.Errorf("get table: %w", err) | 
					
						
							|  |  |  | 	} else if t != nil { | 
					
						
							|  |  |  | 		return t, nil | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | 	t := c.AddTable(&nftables.Table{ | 
					
						
							|  |  |  | 		Family: family, | 
					
						
							|  |  |  | 		Name:   name, | 
					
						
							|  |  |  | 	}) | 
					
						
							|  |  |  | 	if err := c.Flush(); err != nil { | 
					
						
							|  |  |  | 		return nil, fmt.Errorf("add table: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return t, nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | type errorChainNotFound struct { | 
					
						
							|  |  |  | 	chainName string | 
					
						
							|  |  |  | 	tableName string | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (e errorChainNotFound) Error() string { | 
					
						
							|  |  |  | 	return fmt.Sprintf("chain %s not found in table %s", e.chainName, e.tableName) | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // getChainFromTable returns the chain with the given name from the given table. | 
					
						
							|  |  |  | // Note that a chain name is unique within a table. | 
					
						
							|  |  |  | func getChainFromTable(c *nftables.Conn, table *nftables.Table, name string) (*nftables.Chain, error) { | 
					
						
							|  |  |  | 	chains, err := c.ListChainsOfTableFamily(table.Family) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return nil, fmt.Errorf("list chains: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	for _, chain := range chains { | 
					
						
							|  |  |  | 		// Table family is already checked so table name is unique | 
					
						
							|  |  |  | 		if chain.Table.Name == table.Name && chain.Name == name { | 
					
						
							|  |  |  | 			return chain, nil | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return nil, errorChainNotFound{table.Name, name} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-21 09:21:59 -07:00
										 |  |  | // isTSChain reports whether `name` begins with "ts-" (and is thus a | 
					
						
							|  |  |  | // Tailscale-managed chain). | 
					
						
							| 
									
										
										
										
											2023-08-18 18:41:28 +00:00
										 |  |  | func isTSChain(name string) bool { | 
					
						
							|  |  |  | 	return strings.HasPrefix(name, "ts-") | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | // createChainIfNotExist creates a chain with the given name in the given table | 
					
						
							|  |  |  | // if it does not exist. | 
					
						
							|  |  |  | func createChainIfNotExist(c *nftables.Conn, cinfo chainInfo) error { | 
					
						
							| 
									
										
										
										
											2023-10-11 07:26:40 -07:00
										 |  |  | 	_, err := getOrCreateChain(c, cinfo) | 
					
						
							|  |  |  | 	return err | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func getOrCreateChain(c *nftables.Conn, cinfo chainInfo) (*nftables.Chain, error) { | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | 	chain, err := getChainFromTable(c, cinfo.table, cinfo.name) | 
					
						
							|  |  |  | 	if err != nil && !errors.Is(err, errorChainNotFound{cinfo.table.Name, cinfo.name}) { | 
					
						
							| 
									
										
										
										
											2023-10-11 07:26:40 -07:00
										 |  |  | 		return nil, fmt.Errorf("get chain: %w", err) | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | 	} else if err == nil { | 
					
						
							| 
									
										
										
										
											2023-08-18 18:41:28 +00:00
										 |  |  | 		// The chain already exists. If it is a TS chain, check the | 
					
						
							|  |  |  | 		// type/hook/priority, but for "conventional chains" assume they're what | 
					
						
							|  |  |  | 		// we expect (in case iptables-nft/ufw make minor behavior changes in | 
					
						
							|  |  |  | 		// the future). | 
					
						
							| 
									
										
										
										
											2024-04-05 19:43:58 +01:00
										 |  |  | 		if isTSChain(chain.Name) && (chain.Type != cinfo.chainType || *chain.Hooknum != *cinfo.chainHook || *chain.Priority != *cinfo.chainPriority) { | 
					
						
							| 
									
										
										
										
											2023-10-11 07:26:40 -07:00
										 |  |  | 			return nil, fmt.Errorf("chain %s already exists with different type/hook/priority", cinfo.name) | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | 		} | 
					
						
							| 
									
										
										
										
											2023-10-11 07:26:40 -07:00
										 |  |  | 		return chain, nil | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-12 15:55:02 +00:00
										 |  |  | 	chain = c.AddChain(&nftables.Chain{ | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | 		Name:     cinfo.name, | 
					
						
							|  |  |  | 		Table:    cinfo.table, | 
					
						
							|  |  |  | 		Type:     cinfo.chainType, | 
					
						
							|  |  |  | 		Hooknum:  cinfo.chainHook, | 
					
						
							|  |  |  | 		Priority: cinfo.chainPriority, | 
					
						
							| 
									
										
										
										
											2023-08-18 18:41:28 +00:00
										 |  |  | 		Policy:   cinfo.chainPolicy, | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | 	}) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if err := c.Flush(); err != nil { | 
					
						
							| 
									
										
										
										
											2023-10-11 07:26:40 -07:00
										 |  |  | 		return nil, fmt.Errorf("add chain: %w", err) | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-10-11 07:26:40 -07:00
										 |  |  | 	return chain, nil | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-10-10 18:26:52 -07:00
										 |  |  | // NetfilterRunner abstracts helpers to run netfilter commands. It is | 
					
						
							|  |  |  | // implemented by linuxfw.IPTablesRunner and linuxfw.NfTablesRunner. | 
					
						
							|  |  |  | type NetfilterRunner interface { | 
					
						
							|  |  |  | 	// AddLoopbackRule adds a rule to permit loopback traffic to addr. This rule | 
					
						
							|  |  |  | 	// is added only if it does not already exist. | 
					
						
							|  |  |  | 	AddLoopbackRule(addr netip.Addr) error | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// DelLoopbackRule removes the rule added by AddLoopbackRule. | 
					
						
							|  |  |  | 	DelLoopbackRule(addr netip.Addr) error | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// AddHooks adds rules to conventional chains like "FORWARD", "INPUT" and | 
					
						
							|  |  |  | 	// "POSTROUTING" to jump from those chains to tailscale chains. | 
					
						
							|  |  |  | 	AddHooks() error | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// DelHooks deletes rules added by AddHooks. | 
					
						
							|  |  |  | 	DelHooks(logf logger.Logf) error | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// AddChains creates custom Tailscale chains. | 
					
						
							|  |  |  | 	AddChains() error | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// DelChains removes chains added by AddChains. | 
					
						
							|  |  |  | 	DelChains() error | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// AddBase adds rules reused by different other rules. | 
					
						
							|  |  |  | 	AddBase(tunname string) error | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// DelBase removes rules added by AddBase. | 
					
						
							|  |  |  | 	DelBase() error | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// AddSNATRule adds the netfilter rule to SNAT incoming traffic over | 
					
						
							|  |  |  | 	// the Tailscale interface destined for local subnets. An error is | 
					
						
							|  |  |  | 	// returned if the rule already exists. | 
					
						
							|  |  |  | 	AddSNATRule() error | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// DelSNATRule removes the rule added by AddSNATRule. | 
					
						
							|  |  |  | 	DelSNATRule() error | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-05-06 15:22:17 -07:00
										 |  |  | 	// AddStatefulRule adds a netfilter rule for stateful packet filtering | 
					
						
							|  |  |  | 	// using conntrack. | 
					
						
							|  |  |  | 	AddStatefulRule(tunname string) error | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// DelStatefulRule removes a netfilter rule for stateful packet filtering | 
					
						
							|  |  |  | 	// using conntrack. | 
					
						
							|  |  |  | 	DelStatefulRule(tunname string) error | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-10-10 18:26:52 -07:00
										 |  |  | 	// HasIPV6 reports true if the system supports IPv6. | 
					
						
							|  |  |  | 	HasIPV6() bool | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// HasIPV6NAT reports true if the system supports IPv6 NAT. | 
					
						
							|  |  |  | 	HasIPV6NAT() bool | 
					
						
							| 
									
										
										
										
											2023-10-11 07:26:40 -07:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-29 05:23:03 +00:00
										 |  |  | 	// HasIPV6Filter reports true if the system supports IPv6 filter tables | 
					
						
							|  |  |  | 	// This is only meaningful for iptables implementation, where hosts have | 
					
						
							|  |  |  | 	// partial ipables support (i.e missing filter table). For nftables | 
					
						
							|  |  |  | 	// implementation, this will default to the value of HasIPv6(). | 
					
						
							|  |  |  | 	HasIPV6Filter() bool | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-10-11 07:26:40 -07:00
										 |  |  | 	// AddDNATRule adds a rule to the nat/PREROUTING chain to DNAT traffic | 
					
						
							|  |  |  | 	// destined for the given original destination to the given new destination. | 
					
						
							|  |  |  | 	// This is used to forward all traffic destined for the Tailscale interface | 
					
						
							|  |  |  | 	// to the provided destination, as used in the Kubernetes ingress proxies. | 
					
						
							|  |  |  | 	AddDNATRule(origDst, dst netip.Addr) error | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-23 17:30:00 +01:00
										 |  |  | 	// DNATWithLoadBalancer adds a rule to the nat/PREROUTING chain to DNAT | 
					
						
							|  |  |  | 	// traffic destined for the given original destination to the given new | 
					
						
							|  |  |  | 	// destination(s) using round robin to load balance if more than one | 
					
						
							|  |  |  | 	// destination is provided. This is used to forward all traffic destined | 
					
						
							|  |  |  | 	// for the Tailscale interface to the provided destination(s), as used | 
					
						
							|  |  |  | 	// in the Kubernetes ingress proxies. | 
					
						
							|  |  |  | 	DNATWithLoadBalancer(origDst netip.Addr, dsts []netip.Addr) error | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-10-11 07:26:40 -07:00
										 |  |  | 	// AddSNATRuleForDst adds a rule to the nat/POSTROUTING chain to SNAT | 
					
						
							|  |  |  | 	// traffic destined for dst to src. | 
					
						
							|  |  |  | 	// This is used to forward traffic destined for the local machine over | 
					
						
							|  |  |  | 	// the Tailscale interface, as used in the Kubernetes egress proxies. | 
					
						
							|  |  |  | 	AddSNATRuleForDst(src, dst netip.Addr) error | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// DNATNonTailscaleTraffic adds a rule to the nat/PREROUTING chain to DNAT | 
					
						
							|  |  |  | 	// all traffic inbound from any interface except exemptInterface to dst. | 
					
						
							|  |  |  | 	// This is used to forward traffic destined for the local machine over | 
					
						
							| 
									
										
										
										
											2024-04-23 17:30:00 +01:00
										 |  |  | 	// the Tailscale interface, as used in the Kubernetes egress proxies. | 
					
						
							| 
									
										
										
										
											2023-10-11 07:26:40 -07:00
										 |  |  | 	DNATNonTailscaleTraffic(exemptInterface string, dst netip.Addr) error | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// ClampMSSToPMTU adds a rule to the mangle/FORWARD chain to clamp MSS for | 
					
						
							|  |  |  | 	// traffic destined for the provided tun interface. | 
					
						
							|  |  |  | 	ClampMSSToPMTU(tun string, addr netip.Addr) error | 
					
						
							| 
									
										
										
										
											2023-12-05 18:12:02 -05:00
										 |  |  | 
 | 
					
						
							|  |  |  | 	// AddMagicsockPortRule adds a rule to the ts-input chain to accept | 
					
						
							|  |  |  | 	// incoming traffic on the specified port, to allow magicsock to | 
					
						
							|  |  |  | 	// communicate. | 
					
						
							|  |  |  | 	AddMagicsockPortRule(port uint16, network string) error | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// DelMagicsockPortRule removes the rule created by AddMagicsockPortRule, | 
					
						
							|  |  |  | 	// if it exists. | 
					
						
							|  |  |  | 	DelMagicsockPortRule(port uint16, network string) error | 
					
						
							| 
									
										
										
										
											2023-10-10 18:26:52 -07:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-04 12:08:56 -05:00
										 |  |  | // New creates a NetfilterRunner, auto-detecting whether to use | 
					
						
							|  |  |  | // nftables or iptables. | 
					
						
							|  |  |  | // As nftables is still experimental, iptables will be used unless | 
					
						
							|  |  |  | // either the TS_DEBUG_FIREWALL_MODE environment variable, or the prefHint | 
					
						
							|  |  |  | // parameter, is set to one of "nftables" or "auto". | 
					
						
							|  |  |  | func New(logf logger.Logf, prefHint string) (NetfilterRunner, error) { | 
					
						
							|  |  |  | 	mode := detectFirewallMode(logf, prefHint) | 
					
						
							| 
									
										
										
										
											2023-10-10 18:26:52 -07:00
										 |  |  | 	switch mode { | 
					
						
							|  |  |  | 	case FirewallModeIPTables: | 
					
						
							| 
									
										
										
										
											2024-08-03 08:26:12 -07:00
										 |  |  | 		// Note that we don't simply return an newIPTablesRunner here because it | 
					
						
							|  |  |  | 		// would return a `nil` iptablesRunner which is different from returning | 
					
						
							|  |  |  | 		// a nil NetfilterRunner. | 
					
						
							|  |  |  | 		ipr, err := newIPTablesRunner(logf) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			return nil, err | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		return ipr, nil | 
					
						
							| 
									
										
										
										
											2023-10-10 18:26:52 -07:00
										 |  |  | 	case FirewallModeNfTables: | 
					
						
							| 
									
										
										
										
											2024-08-03 08:26:12 -07:00
										 |  |  | 		// Note that we don't simply return an newNfTablesRunner here because it | 
					
						
							|  |  |  | 		// would return a `nil` nftablesRunner which is different from returning | 
					
						
							|  |  |  | 		// a nil NetfilterRunner. | 
					
						
							|  |  |  | 		nfr, err := newNfTablesRunner(logf) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			return nil, err | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		return nfr, nil | 
					
						
							| 
									
										
										
										
											2023-10-10 18:26:52 -07:00
										 |  |  | 	default: | 
					
						
							|  |  |  | 		return nil, fmt.Errorf("unknown firewall mode %v", mode) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // newNfTablesRunner creates a new nftablesRunner without guaranteeing | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | // the existence of the tables and chains. | 
					
						
							| 
									
										
										
										
											2023-10-10 18:26:52 -07:00
										 |  |  | func newNfTablesRunner(logf logger.Logf) (*nftablesRunner, error) { | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | 	conn, err := nftables.New() | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return nil, fmt.Errorf("nftables connection: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2024-05-14 08:51:53 +01:00
										 |  |  | 	return newNfTablesRunnerWithConn(logf, conn), nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func newNfTablesRunnerWithConn(logf logger.Logf, conn *nftables.Conn) *nftablesRunner { | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | 	nft4 := &nftable{Proto: nftables.TableFamilyIPv4} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-21 15:43:12 -07:00
										 |  |  | 	v6err := CheckIPv6(logf) | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | 	if v6err != nil { | 
					
						
							| 
									
										
										
										
											2023-09-02 11:09:38 -07:00
										 |  |  | 		logf("disabling tunneled IPv6 due to system IPv6 config: %v", v6err) | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | 	} | 
					
						
							|  |  |  | 	supportsV6 := v6err == nil | 
					
						
							|  |  |  | 	var nft6 *nftable | 
					
						
							| 
									
										
										
										
											2024-03-06 21:53:51 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | 	if supportsV6 { | 
					
						
							|  |  |  | 		nft6 = &nftable{Proto: nftables.TableFamilyIPv6} | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2024-05-14 08:51:53 +01:00
										 |  |  | 	logf("netfilter running in nftables mode, v6 = %v", supportsV6) | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | 	// TODO(KevinLiang10): convert iptables rule to nftable rules if they exist in the iptables | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return &nftablesRunner{ | 
					
						
							| 
									
										
										
										
											2024-03-29 05:23:03 +00:00
										 |  |  | 		conn:        conn, | 
					
						
							|  |  |  | 		nft4:        nft4, | 
					
						
							|  |  |  | 		nft6:        nft6, | 
					
						
							|  |  |  | 		v6Available: supportsV6, | 
					
						
							| 
									
										
										
										
											2024-05-14 08:51:53 +01:00
										 |  |  | 	} | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // newLoadSaddrExpr creates a new nftables expression that loads the source | 
					
						
							|  |  |  | // address of the packet into the given register. | 
					
						
							|  |  |  | func newLoadSaddrExpr(proto nftables.TableFamily, destReg uint32) (expr.Any, error) { | 
					
						
							|  |  |  | 	switch proto { | 
					
						
							|  |  |  | 	case nftables.TableFamilyIPv4: | 
					
						
							|  |  |  | 		return &expr.Payload{ | 
					
						
							|  |  |  | 			DestRegister: destReg, | 
					
						
							|  |  |  | 			Base:         expr.PayloadBaseNetworkHeader, | 
					
						
							|  |  |  | 			Offset:       12, | 
					
						
							|  |  |  | 			Len:          4, | 
					
						
							|  |  |  | 		}, nil | 
					
						
							|  |  |  | 	case nftables.TableFamilyIPv6: | 
					
						
							|  |  |  | 		return &expr.Payload{ | 
					
						
							|  |  |  | 			DestRegister: destReg, | 
					
						
							|  |  |  | 			Base:         expr.PayloadBaseNetworkHeader, | 
					
						
							|  |  |  | 			Offset:       8, | 
					
						
							|  |  |  | 			Len:          16, | 
					
						
							|  |  |  | 		}, nil | 
					
						
							|  |  |  | 	default: | 
					
						
							|  |  |  | 		return nil, fmt.Errorf("table family %v is neither IPv4 nor IPv6", proto) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-05 18:12:02 -05:00
										 |  |  | // newLoadDportExpr creates a new nftables express that loads the desination port | 
					
						
							|  |  |  | // of a TCP/UDP packet into the given register. | 
					
						
							|  |  |  | func newLoadDportExpr(destReg uint32) expr.Any { | 
					
						
							|  |  |  | 	return &expr.Payload{ | 
					
						
							|  |  |  | 		DestRegister: destReg, | 
					
						
							|  |  |  | 		Base:         expr.PayloadBaseTransportHeader, | 
					
						
							|  |  |  | 		Offset:       2, | 
					
						
							|  |  |  | 		Len:          2, | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-10-10 18:26:52 -07:00
										 |  |  | // HasIPV6 reports true if the system supports IPv6. | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | func (n *nftablesRunner) HasIPV6() bool { | 
					
						
							|  |  |  | 	return n.v6Available | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-29 05:23:03 +00:00
										 |  |  | // HasIPV6NAT returns true if the system supports IPv6. | 
					
						
							|  |  |  | // Kernel support for nftables was added after support for IPv6 | 
					
						
							|  |  |  | // NAT, so no need for a separate IPv6 NAT support check like we do for iptables. | 
					
						
							|  |  |  | // https://tldp.org/HOWTO/Linux+IPv6-HOWTO/ch18s04.html | 
					
						
							|  |  |  | // https://wiki.nftables.org/wiki-nftables/index.php/Building_and_installing_nftables_from_sources | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | func (n *nftablesRunner) HasIPV6NAT() bool { | 
					
						
							| 
									
										
										
										
											2024-03-29 05:23:03 +00:00
										 |  |  | 	return n.v6Available | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // HasIPV6Filter returns true if system supports IPv6. There are no known edge | 
					
						
							|  |  |  | // cases where nftables running on a host that supports IPv6 would not support | 
					
						
							|  |  |  | // filter table. | 
					
						
							|  |  |  | func (n *nftablesRunner) HasIPV6Filter() bool { | 
					
						
							|  |  |  | 	return n.v6Available | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // findRule iterates through the rules to find the rule with matching expressions. | 
					
						
							|  |  |  | func findRule(conn *nftables.Conn, rule *nftables.Rule) (*nftables.Rule, error) { | 
					
						
							|  |  |  | 	rules, err := conn.GetRules(rule.Table, rule.Chain) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return nil, fmt.Errorf("get nftables rules: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	if len(rules) == 0 { | 
					
						
							|  |  |  | 		return nil, nil | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ruleLoop: | 
					
						
							|  |  |  | 	for _, r := range rules { | 
					
						
							|  |  |  | 		if len(r.Exprs) != len(rule.Exprs) { | 
					
						
							|  |  |  | 			continue | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		for i, e := range r.Exprs { | 
					
						
							| 
									
										
										
										
											2023-08-18 18:41:28 +00:00
										 |  |  | 			// Skip counter expressions, as they will not match. | 
					
						
							|  |  |  | 			if _, ok := e.(*expr.Counter); ok { | 
					
						
							|  |  |  | 				continue | 
					
						
							|  |  |  | 			} | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | 			if !reflect.DeepEqual(e, rule.Exprs[i]) { | 
					
						
							|  |  |  | 				continue ruleLoop | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		return r, nil | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return nil, nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func createLoopbackRule( | 
					
						
							|  |  |  | 	proto nftables.TableFamily, | 
					
						
							|  |  |  | 	table *nftables.Table, | 
					
						
							|  |  |  | 	chain *nftables.Chain, | 
					
						
							|  |  |  | 	addr netip.Addr, | 
					
						
							|  |  |  | ) (*nftables.Rule, error) { | 
					
						
							|  |  |  | 	saddrExpr, err := newLoadSaddrExpr(proto, 1) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return nil, fmt.Errorf("newLoadSaddrExpr: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	loopBackRule := &nftables.Rule{ | 
					
						
							|  |  |  | 		Table: table, | 
					
						
							|  |  |  | 		Chain: chain, | 
					
						
							|  |  |  | 		Exprs: []expr.Any{ | 
					
						
							|  |  |  | 			&expr.Meta{ | 
					
						
							|  |  |  | 				Key:      expr.MetaKeyIIFNAME, | 
					
						
							|  |  |  | 				Register: 1, | 
					
						
							|  |  |  | 			}, | 
					
						
							|  |  |  | 			&expr.Cmp{ | 
					
						
							|  |  |  | 				Op:       expr.CmpOpEq, | 
					
						
							|  |  |  | 				Register: 1, | 
					
						
							|  |  |  | 				Data:     []byte("lo"), | 
					
						
							|  |  |  | 			}, | 
					
						
							|  |  |  | 			saddrExpr, | 
					
						
							|  |  |  | 			&expr.Cmp{ | 
					
						
							|  |  |  | 				Op:       expr.CmpOpEq, | 
					
						
							|  |  |  | 				Register: 1, | 
					
						
							|  |  |  | 				Data:     addr.AsSlice(), | 
					
						
							|  |  |  | 			}, | 
					
						
							|  |  |  | 			&expr.Counter{}, | 
					
						
							|  |  |  | 			&expr.Verdict{ | 
					
						
							|  |  |  | 				Kind: expr.VerdictAccept, | 
					
						
							|  |  |  | 			}, | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return loopBackRule, nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // insertLoopbackRule inserts the TS loop back rule into | 
					
						
							|  |  |  | // the given chain as the first rule if it does not exist. | 
					
						
							|  |  |  | func insertLoopbackRule( | 
					
						
							|  |  |  | 	conn *nftables.Conn, proto nftables.TableFamily, | 
					
						
							|  |  |  | 	table *nftables.Table, chain *nftables.Chain, addr netip.Addr) error { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	loopBackRule, err := createLoopbackRule(proto, table, chain, addr) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("create loopback rule: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// If TestDial is set, we are running in test mode and we should not | 
					
						
							|  |  |  | 	// find rule because header will mismatch. | 
					
						
							|  |  |  | 	if conn.TestDial == nil { | 
					
						
							|  |  |  | 		// Check if the rule already exists. | 
					
						
							|  |  |  | 		rule, err := findRule(conn, loopBackRule) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			return fmt.Errorf("find rule: %w", err) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		if rule != nil { | 
					
						
							|  |  |  | 			// Rule already exists, no need to insert. | 
					
						
							|  |  |  | 			return nil | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// This inserts the rule to the top of the chain | 
					
						
							|  |  |  | 	_ = conn.InsertRule(loopBackRule) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if err = conn.Flush(); err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("insert rule: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // getNFTByAddr returns the nftables with correct IP family | 
					
						
							|  |  |  | // that we will be using for the given address. | 
					
						
							| 
									
										
										
										
											2024-06-25 14:13:49 +01:00
										 |  |  | func (n *nftablesRunner) getNFTByAddr(addr netip.Addr) (*nftable, error) { | 
					
						
							|  |  |  | 	if addr.Is6() && !n.v6Available { | 
					
						
							|  |  |  | 		return nil, fmt.Errorf("nftables for IPv6 are not available on this host") | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | 	if addr.Is6() { | 
					
						
							| 
									
										
										
										
											2024-06-25 14:13:49 +01:00
										 |  |  | 		return n.nft6, nil | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | 	} | 
					
						
							| 
									
										
										
										
											2024-06-25 14:13:49 +01:00
										 |  |  | 	return n.nft4, nil | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // AddLoopbackRule adds an nftables rule to permit loopback traffic to | 
					
						
							|  |  |  | // a local Tailscale IP. This rule is added only if it does not already exist. | 
					
						
							|  |  |  | func (n *nftablesRunner) AddLoopbackRule(addr netip.Addr) error { | 
					
						
							| 
									
										
										
										
											2024-06-25 14:13:49 +01:00
										 |  |  | 	nf, err := n.getNFTByAddr(addr) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("error setting up nftables for IP family of %v: %w", addr, err) | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | 	inputChain, err := getChainFromTable(n.conn, nf.Filter, chainNameInput) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("get input chain: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if err := insertLoopbackRule(n.conn, nf.Proto, nf.Filter, inputChain, addr); err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("add loopback rule: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // DelLoopbackRule removes the nftables rule permitting loopback | 
					
						
							|  |  |  | // traffic to a Tailscale IP. | 
					
						
							|  |  |  | func (n *nftablesRunner) DelLoopbackRule(addr netip.Addr) error { | 
					
						
							| 
									
										
										
										
											2024-06-25 14:13:49 +01:00
										 |  |  | 	nf, err := n.getNFTByAddr(addr) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("error setting up nftables for IP family of %v: %w", addr, err) | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | 	inputChain, err := getChainFromTable(n.conn, nf.Filter, chainNameInput) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("get input chain: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	loopBackRule, err := createLoopbackRule(nf.Proto, nf.Filter, inputChain, addr) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("create loopback rule: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	existingLoopBackRule, err := findRule(n.conn, loopBackRule) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("find loop back rule: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	if existingLoopBackRule == nil { | 
					
						
							|  |  |  | 		// Rule does not exist, no need to delete. | 
					
						
							|  |  |  | 		return nil | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if err := n.conn.DelRule(existingLoopBackRule); err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("delete rule: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return n.conn.Flush() | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-05-14 08:51:53 +01:00
										 |  |  | // getTables returns tables for IP families that this host was determined to | 
					
						
							|  |  |  | // support (either IPv4 and IPv6 or just IPv4). | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | func (n *nftablesRunner) getTables() []*nftable { | 
					
						
							| 
									
										
										
										
											2024-05-14 08:51:53 +01:00
										 |  |  | 	if n.HasIPV6() { | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | 		return []*nftable{n.nft4, n.nft6} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return []*nftable{n.nft4} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // AddChains creates custom Tailscale chains in netfilter via nftables | 
					
						
							|  |  |  | // if the ts-chain doesn't already exist. | 
					
						
							|  |  |  | func (n *nftablesRunner) AddChains() error { | 
					
						
							| 
									
										
										
										
											2023-08-18 18:41:28 +00:00
										 |  |  | 	polAccept := nftables.ChainPolicyAccept | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | 	for _, table := range n.getTables() { | 
					
						
							| 
									
										
										
										
											2023-08-18 18:41:28 +00:00
										 |  |  | 		// Create the filter table if it doesn't exist, this table name is the same | 
					
						
							|  |  |  | 		// as the name used by iptables-nft and ufw. We install rules into the | 
					
						
							|  |  |  | 		// same conventional table so that `accept` verdicts from our jump | 
					
						
							|  |  |  | 		// chains are conclusive. | 
					
						
							|  |  |  | 		filter, err := createTableIfNotExist(n.conn, table.Proto, "filter") | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			return fmt.Errorf("create table: %w", err) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		table.Filter = filter | 
					
						
							| 
									
										
										
										
											2023-08-18 18:41:28 +00:00
										 |  |  | 		// Adding the "conventional chains" that are used by iptables-nft and ufw. | 
					
						
							|  |  |  | 		if err = createChainIfNotExist(n.conn, chainInfo{filter, "FORWARD", nftables.ChainTypeFilter, nftables.ChainHookForward, nftables.ChainPriorityFilter, &polAccept}); err != nil { | 
					
						
							|  |  |  | 			return fmt.Errorf("create forward chain: %w", err) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		if err = createChainIfNotExist(n.conn, chainInfo{filter, "INPUT", nftables.ChainTypeFilter, nftables.ChainHookInput, nftables.ChainPriorityFilter, &polAccept}); err != nil { | 
					
						
							|  |  |  | 			return fmt.Errorf("create input chain: %w", err) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		// Adding the tailscale chains that contain our rules. | 
					
						
							|  |  |  | 		if err = createChainIfNotExist(n.conn, chainInfo{filter, chainNameForward, chainTypeRegular, nil, nil, nil}); err != nil { | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | 			return fmt.Errorf("create forward chain: %w", err) | 
					
						
							|  |  |  | 		} | 
					
						
							| 
									
										
										
										
											2023-08-18 18:41:28 +00:00
										 |  |  | 		if err = createChainIfNotExist(n.conn, chainInfo{filter, chainNameInput, chainTypeRegular, nil, nil, nil}); err != nil { | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | 			return fmt.Errorf("create input chain: %w", err) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-18 18:41:28 +00:00
										 |  |  | 		// Create the nat table if it doesn't exist, this table name is the same | 
					
						
							|  |  |  | 		// as the name used by iptables-nft and ufw. We install rules into the | 
					
						
							|  |  |  | 		// same conventional table so that `accept` verdicts from our jump | 
					
						
							|  |  |  | 		// chains are conclusive. | 
					
						
							|  |  |  | 		nat, err := createTableIfNotExist(n.conn, table.Proto, "nat") | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			return fmt.Errorf("create table: %w", err) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		table.Nat = nat | 
					
						
							| 
									
										
										
										
											2023-08-18 18:41:28 +00:00
										 |  |  | 		// Adding the "conventional chains" that are used by iptables-nft and ufw. | 
					
						
							|  |  |  | 		if err = createChainIfNotExist(n.conn, chainInfo{nat, "POSTROUTING", nftables.ChainTypeNAT, nftables.ChainHookPostrouting, nftables.ChainPriorityNATSource, &polAccept}); err != nil { | 
					
						
							|  |  |  | 			return fmt.Errorf("create postrouting chain: %w", err) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		// Adding the tailscale chain that contains our rules. | 
					
						
							|  |  |  | 		if err = createChainIfNotExist(n.conn, chainInfo{nat, chainNamePostrouting, chainTypeRegular, nil, nil, nil}); err != nil { | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | 			return fmt.Errorf("create postrouting chain: %w", err) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return n.conn.Flush() | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-10-17 19:51:40 +00:00
										 |  |  | // These are dummy chains and tables we create to detect if nftables is | 
					
						
							|  |  |  | // available. We create them, then delete them. If we can create and delete | 
					
						
							|  |  |  | // them, then we can use nftables. If we can't, then we assume that we're | 
					
						
							|  |  |  | // running on a system that doesn't support nftables. See | 
					
						
							|  |  |  | // createDummyPostroutingChains. | 
					
						
							|  |  |  | const ( | 
					
						
							|  |  |  | 	tsDummyChainName = "ts-test-postrouting" | 
					
						
							|  |  |  | 	tsDummyTableName = "ts-test-nat" | 
					
						
							|  |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // createDummyPostroutingChains creates dummy postrouting chains in netfilter | 
					
						
							|  |  |  | // via netfilter via nftables, as a last resort measure to detect that nftables | 
					
						
							|  |  |  | // can be used. It cleans up the dummy chains after creation. | 
					
						
							|  |  |  | func (n *nftablesRunner) createDummyPostroutingChains() (retErr error) { | 
					
						
							|  |  |  | 	polAccept := ptr.To(nftables.ChainPolicyAccept) | 
					
						
							| 
									
										
										
										
											2024-05-14 08:51:53 +01:00
										 |  |  | 	for _, table := range n.getTables() { | 
					
						
							| 
									
										
										
										
											2023-10-17 19:51:40 +00:00
										 |  |  | 		nat, err := createTableIfNotExist(n.conn, table.Proto, tsDummyTableName) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			return fmt.Errorf("create nat table: %w", err) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		defer func(fm nftables.TableFamily) { | 
					
						
							| 
									
										
										
										
											2023-10-18 11:22:28 -07:00
										 |  |  | 			if err := deleteTableIfExists(n.conn, fm, tsDummyTableName); err != nil && retErr == nil { | 
					
						
							| 
									
										
										
										
											2023-10-17 19:51:40 +00:00
										 |  |  | 				retErr = fmt.Errorf("delete %q table: %w", tsDummyTableName, err) | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 		}(table.Proto) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		table.Nat = nat | 
					
						
							|  |  |  | 		if err = createChainIfNotExist(n.conn, chainInfo{nat, tsDummyChainName, nftables.ChainTypeNAT, nftables.ChainHookPostrouting, nftables.ChainPriorityNATSource, polAccept}); err != nil { | 
					
						
							|  |  |  | 			return fmt.Errorf("create %q chain: %w", tsDummyChainName, err) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		if err := deleteChainIfExists(n.conn, nat, tsDummyChainName); err != nil { | 
					
						
							|  |  |  | 			return fmt.Errorf("delete %q chain: %w", tsDummyChainName, err) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | // deleteChainIfExists deletes a chain if it exists. | 
					
						
							|  |  |  | func deleteChainIfExists(c *nftables.Conn, table *nftables.Table, name string) error { | 
					
						
							|  |  |  | 	chain, err := getChainFromTable(c, table, name) | 
					
						
							|  |  |  | 	if err != nil && !errors.Is(err, errorChainNotFound{table.Name, name}) { | 
					
						
							|  |  |  | 		return fmt.Errorf("get chain: %w", err) | 
					
						
							|  |  |  | 	} else if err != nil { | 
					
						
							|  |  |  | 		// If the chain doesn't exist, we don't need to delete it. | 
					
						
							|  |  |  | 		return nil | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	c.FlushChain(chain) | 
					
						
							|  |  |  | 	c.DelChain(chain) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if err := c.Flush(); err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("flush and delete chain: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // DelChains removes the custom Tailscale chains from netfilter via nftables. | 
					
						
							|  |  |  | func (n *nftablesRunner) DelChains() error { | 
					
						
							|  |  |  | 	for _, table := range n.getTables() { | 
					
						
							|  |  |  | 		if err := deleteChainIfExists(n.conn, table.Filter, chainNameForward); err != nil { | 
					
						
							|  |  |  | 			return fmt.Errorf("delete chain: %w", err) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		if err := deleteChainIfExists(n.conn, table.Filter, chainNameInput); err != nil { | 
					
						
							|  |  |  | 			return fmt.Errorf("delete chain: %w", err) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if err := deleteChainIfExists(n.conn, n.nft4.Nat, chainNamePostrouting); err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("delete chain: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-05-14 08:51:53 +01:00
										 |  |  | 	if n.HasIPV6NAT() { | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | 		if err := deleteChainIfExists(n.conn, n.nft6.Nat, chainNamePostrouting); err != nil { | 
					
						
							|  |  |  | 			return fmt.Errorf("delete chain: %w", err) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if err := n.conn.Flush(); err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("flush: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-18 18:41:28 +00:00
										 |  |  | // createHookRule creates a rule to jump from a hooked chain to a regular chain. | 
					
						
							|  |  |  | func createHookRule(table *nftables.Table, fromChain *nftables.Chain, toChainName string) *nftables.Rule { | 
					
						
							|  |  |  | 	exprs := []expr.Any{ | 
					
						
							|  |  |  | 		&expr.Counter{}, | 
					
						
							|  |  |  | 		&expr.Verdict{ | 
					
						
							|  |  |  | 			Kind:  expr.VerdictJump, | 
					
						
							|  |  |  | 			Chain: toChainName, | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	rule := &nftables.Rule{ | 
					
						
							|  |  |  | 		Table: table, | 
					
						
							|  |  |  | 		Chain: fromChain, | 
					
						
							|  |  |  | 		Exprs: exprs, | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return rule | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // addHookRule adds a rule to jump from a hooked chain to a regular chain at top of the hooked chain. | 
					
						
							|  |  |  | func addHookRule(conn *nftables.Conn, table *nftables.Table, fromChain *nftables.Chain, toChainName string) error { | 
					
						
							|  |  |  | 	rule := createHookRule(table, fromChain, toChainName) | 
					
						
							|  |  |  | 	_ = conn.InsertRule(rule) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if err := conn.Flush(); err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("flush add rule: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // AddHooks is adding rules to conventional chains like "FORWARD", "INPUT" and "POSTROUTING" | 
					
						
							|  |  |  | // in tables and jump from those chains to tailscale chains. | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | func (n *nftablesRunner) AddHooks() error { | 
					
						
							| 
									
										
										
										
											2023-08-18 18:41:28 +00:00
										 |  |  | 	conn := n.conn | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	for _, table := range n.getTables() { | 
					
						
							|  |  |  | 		inputChain, err := getChainFromTable(conn, table.Filter, "INPUT") | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			return fmt.Errorf("get INPUT chain: %w", err) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		err = addHookRule(conn, table.Filter, inputChain, chainNameInput) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			return fmt.Errorf("Addhook: %w", err) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		forwardChain, err := getChainFromTable(conn, table.Filter, "FORWARD") | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			return fmt.Errorf("get FORWARD chain: %w", err) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		err = addHookRule(conn, table.Filter, forwardChain, chainNameForward) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			return fmt.Errorf("Addhook: %w", err) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		postroutingChain, err := getChainFromTable(conn, table.Nat, "POSTROUTING") | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			return fmt.Errorf("get INPUT chain: %w", err) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		err = addHookRule(conn, table.Nat, postroutingChain, chainNamePostrouting) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			return fmt.Errorf("Addhook: %w", err) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | 	return nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-18 18:41:28 +00:00
										 |  |  | // delHookRule deletes a rule that jumps from a hooked chain to a regular chain. | 
					
						
							|  |  |  | func delHookRule(conn *nftables.Conn, table *nftables.Table, fromChain *nftables.Chain, toChainName string) error { | 
					
						
							|  |  |  | 	rule := createHookRule(table, fromChain, toChainName) | 
					
						
							|  |  |  | 	existingRule, err := findRule(conn, rule) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("Failed to find hook rule: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if existingRule == nil { | 
					
						
							|  |  |  | 		return nil | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	_ = conn.DelRule(existingRule) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if err := conn.Flush(); err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("flush del hook rule: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // DelHooks is deleting the rules added to conventional chains to jump to tailscale chains. | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | func (n *nftablesRunner) DelHooks(logf logger.Logf) error { | 
					
						
							| 
									
										
										
										
											2023-08-18 18:41:28 +00:00
										 |  |  | 	conn := n.conn | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	for _, table := range n.getTables() { | 
					
						
							|  |  |  | 		inputChain, err := getChainFromTable(conn, table.Filter, "INPUT") | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			return fmt.Errorf("get INPUT chain: %w", err) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		err = delHookRule(conn, table.Filter, inputChain, chainNameInput) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			return fmt.Errorf("delhook: %w", err) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		forwardChain, err := getChainFromTable(conn, table.Filter, "FORWARD") | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			return fmt.Errorf("get FORWARD chain: %w", err) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		err = delHookRule(conn, table.Filter, forwardChain, chainNameForward) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			return fmt.Errorf("delhook: %w", err) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		postroutingChain, err := getChainFromTable(conn, table.Nat, "POSTROUTING") | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			return fmt.Errorf("get INPUT chain: %w", err) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		err = delHookRule(conn, table.Nat, postroutingChain, chainNamePostrouting) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			return fmt.Errorf("delhook: %w", err) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | 	return nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // maskof returns the mask of the given prefix in big endian bytes. | 
					
						
							|  |  |  | func maskof(pfx netip.Prefix) []byte { | 
					
						
							|  |  |  | 	mask := make([]byte, 4) | 
					
						
							|  |  |  | 	binary.BigEndian.PutUint32(mask, ^(uint32(0xffff_ffff) >> pfx.Bits())) | 
					
						
							|  |  |  | 	return mask | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // createRangeRule creates a rule that matches packets with source IP from the give | 
					
						
							|  |  |  | // range (like CGNAT range or ChromeOSVM range) and the interface is not the tunname, | 
					
						
							|  |  |  | // and makes the given decision. Only IPv4 is supported. | 
					
						
							|  |  |  | func createRangeRule( | 
					
						
							|  |  |  | 	table *nftables.Table, chain *nftables.Chain, | 
					
						
							|  |  |  | 	tunname string, rng netip.Prefix, decision expr.VerdictKind, | 
					
						
							|  |  |  | ) (*nftables.Rule, error) { | 
					
						
							|  |  |  | 	if rng.Addr().Is6() { | 
					
						
							|  |  |  | 		return nil, errors.New("IPv6 is not supported") | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	saddrExpr, err := newLoadSaddrExpr(nftables.TableFamilyIPv4, 1) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return nil, fmt.Errorf("newLoadSaddrExpr: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	netip := rng.Addr().AsSlice() | 
					
						
							|  |  |  | 	mask := maskof(rng) | 
					
						
							|  |  |  | 	rule := &nftables.Rule{ | 
					
						
							|  |  |  | 		Table: table, | 
					
						
							|  |  |  | 		Chain: chain, | 
					
						
							|  |  |  | 		Exprs: []expr.Any{ | 
					
						
							|  |  |  | 			&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, | 
					
						
							|  |  |  | 			&expr.Cmp{ | 
					
						
							|  |  |  | 				Op:       expr.CmpOpNeq, | 
					
						
							|  |  |  | 				Register: 1, | 
					
						
							|  |  |  | 				Data:     []byte(tunname), | 
					
						
							|  |  |  | 			}, | 
					
						
							|  |  |  | 			saddrExpr, | 
					
						
							|  |  |  | 			&expr.Bitwise{ | 
					
						
							|  |  |  | 				SourceRegister: 1, | 
					
						
							|  |  |  | 				DestRegister:   1, | 
					
						
							|  |  |  | 				Len:            4, | 
					
						
							|  |  |  | 				Mask:           mask, | 
					
						
							|  |  |  | 				Xor:            []byte{0x00, 0x00, 0x00, 0x00}, | 
					
						
							|  |  |  | 			}, | 
					
						
							|  |  |  | 			&expr.Cmp{ | 
					
						
							|  |  |  | 				Op:       expr.CmpOpEq, | 
					
						
							|  |  |  | 				Register: 1, | 
					
						
							|  |  |  | 				Data:     netip, | 
					
						
							|  |  |  | 			}, | 
					
						
							|  |  |  | 			&expr.Counter{}, | 
					
						
							|  |  |  | 			&expr.Verdict{ | 
					
						
							|  |  |  | 				Kind: decision, | 
					
						
							|  |  |  | 			}, | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return rule, nil | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // addReturnChromeOSVMRangeRule adds a rule to return if the source IP | 
					
						
							|  |  |  | // is in the ChromeOS VM range. | 
					
						
							|  |  |  | func addReturnChromeOSVMRangeRule(c *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error { | 
					
						
							|  |  |  | 	rule, err := createRangeRule(table, chain, tunname, tsaddr.ChromeOSVMRange(), expr.VerdictReturn) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("create rule: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	_ = c.AddRule(rule) | 
					
						
							|  |  |  | 	if err = c.Flush(); err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("add rule: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // addDropCGNATRangeRule adds a rule to drop if the source IP is in the | 
					
						
							|  |  |  | // CGNAT range. | 
					
						
							|  |  |  | func addDropCGNATRangeRule(c *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error { | 
					
						
							|  |  |  | 	rule, err := createRangeRule(table, chain, tunname, tsaddr.CGNATRange(), expr.VerdictDrop) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("create rule: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	_ = c.AddRule(rule) | 
					
						
							|  |  |  | 	if err = c.Flush(); err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("add rule: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // createSetSubnetRouteMarkRule creates a rule to set the subnet route | 
					
						
							|  |  |  | // mark if the packet is from the given interface. | 
					
						
							|  |  |  | func createSetSubnetRouteMarkRule(table *nftables.Table, chain *nftables.Chain, tunname string) (*nftables.Rule, error) { | 
					
						
							|  |  |  | 	hexTsFwmarkMaskNeg := getTailscaleFwmarkMaskNeg() | 
					
						
							|  |  |  | 	hexTSSubnetRouteMark := getTailscaleSubnetRouteMark() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	rule := &nftables.Rule{ | 
					
						
							|  |  |  | 		Table: table, | 
					
						
							|  |  |  | 		Chain: chain, | 
					
						
							|  |  |  | 		Exprs: []expr.Any{ | 
					
						
							|  |  |  | 			&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, | 
					
						
							|  |  |  | 			&expr.Cmp{ | 
					
						
							|  |  |  | 				Op:       expr.CmpOpEq, | 
					
						
							|  |  |  | 				Register: 1, | 
					
						
							|  |  |  | 				Data:     []byte(tunname), | 
					
						
							|  |  |  | 			}, | 
					
						
							|  |  |  | 			&expr.Counter{}, | 
					
						
							|  |  |  | 			&expr.Meta{Key: expr.MetaKeyMARK, Register: 1}, | 
					
						
							|  |  |  | 			&expr.Bitwise{ | 
					
						
							|  |  |  | 				SourceRegister: 1, | 
					
						
							|  |  |  | 				DestRegister:   1, | 
					
						
							|  |  |  | 				Len:            4, | 
					
						
							|  |  |  | 				Mask:           hexTsFwmarkMaskNeg, | 
					
						
							|  |  |  | 				Xor:            hexTSSubnetRouteMark, | 
					
						
							|  |  |  | 			}, | 
					
						
							|  |  |  | 			&expr.Meta{ | 
					
						
							|  |  |  | 				Key:            expr.MetaKeyMARK, | 
					
						
							|  |  |  | 				SourceRegister: true, | 
					
						
							|  |  |  | 				Register:       1, | 
					
						
							|  |  |  | 			}, | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return rule, nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // addSetSubnetRouteMarkRule adds a rule to set the subnet route mark | 
					
						
							|  |  |  | // if the packet is from the given interface. | 
					
						
							|  |  |  | func addSetSubnetRouteMarkRule(c *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error { | 
					
						
							|  |  |  | 	rule, err := createSetSubnetRouteMarkRule(table, chain, tunname) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("create rule: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	_ = c.AddRule(rule) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if err := c.Flush(); err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("add rule: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // createDropOutgoingPacketFromCGNATRangeRuleWithTunname creates a rule to drop | 
					
						
							|  |  |  | // outgoing packets from the CGNAT range. | 
					
						
							|  |  |  | func createDropOutgoingPacketFromCGNATRangeRuleWithTunname(table *nftables.Table, chain *nftables.Chain, tunname string) (*nftables.Rule, error) { | 
					
						
							|  |  |  | 	_, ipNet, err := net.ParseCIDR(tsaddr.CGNATRange().String()) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return nil, fmt.Errorf("parse cidr: %v", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	mask, err := hex.DecodeString(ipNet.Mask.String()) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return nil, fmt.Errorf("decode mask: %v", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	netip := ipNet.IP.Mask(ipNet.Mask).To4() | 
					
						
							|  |  |  | 	saddrExpr, err := newLoadSaddrExpr(nftables.TableFamilyIPv4, 1) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return nil, fmt.Errorf("newLoadSaddrExpr: %v", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	rule := &nftables.Rule{ | 
					
						
							|  |  |  | 		Table: table, | 
					
						
							|  |  |  | 		Chain: chain, | 
					
						
							|  |  |  | 		Exprs: []expr.Any{ | 
					
						
							|  |  |  | 			&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, | 
					
						
							|  |  |  | 			&expr.Cmp{ | 
					
						
							|  |  |  | 				Op:       expr.CmpOpEq, | 
					
						
							|  |  |  | 				Register: 1, | 
					
						
							|  |  |  | 				Data:     []byte(tunname), | 
					
						
							|  |  |  | 			}, | 
					
						
							|  |  |  | 			saddrExpr, | 
					
						
							|  |  |  | 			&expr.Bitwise{ | 
					
						
							|  |  |  | 				SourceRegister: 1, | 
					
						
							|  |  |  | 				DestRegister:   1, | 
					
						
							|  |  |  | 				Len:            4, | 
					
						
							|  |  |  | 				Mask:           mask, | 
					
						
							|  |  |  | 				Xor:            []byte{0x00, 0x00, 0x00, 0x00}, | 
					
						
							|  |  |  | 			}, | 
					
						
							|  |  |  | 			&expr.Cmp{ | 
					
						
							|  |  |  | 				Op:       expr.CmpOpEq, | 
					
						
							|  |  |  | 				Register: 1, | 
					
						
							|  |  |  | 				Data:     netip, | 
					
						
							|  |  |  | 			}, | 
					
						
							|  |  |  | 			&expr.Counter{}, | 
					
						
							|  |  |  | 			&expr.Verdict{ | 
					
						
							|  |  |  | 				Kind: expr.VerdictDrop, | 
					
						
							|  |  |  | 			}, | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return rule, nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // addDropOutgoingPacketFromCGNATRangeRuleWithTunname adds a rule to drop | 
					
						
							|  |  |  | // outgoing packets from the CGNAT range. | 
					
						
							|  |  |  | func addDropOutgoingPacketFromCGNATRangeRuleWithTunname(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error { | 
					
						
							|  |  |  | 	rule, err := createDropOutgoingPacketFromCGNATRangeRuleWithTunname(table, chain, tunname) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("create rule: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	_ = conn.AddRule(rule) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if err := conn.Flush(); err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("add rule: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // createAcceptOutgoingPacketRule creates a rule to accept outgoing packets | 
					
						
							|  |  |  | // from the given interface. | 
					
						
							|  |  |  | func createAcceptOutgoingPacketRule(table *nftables.Table, chain *nftables.Chain, tunname string) *nftables.Rule { | 
					
						
							|  |  |  | 	return &nftables.Rule{ | 
					
						
							|  |  |  | 		Table: table, | 
					
						
							|  |  |  | 		Chain: chain, | 
					
						
							|  |  |  | 		Exprs: []expr.Any{ | 
					
						
							|  |  |  | 			&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, | 
					
						
							|  |  |  | 			&expr.Cmp{ | 
					
						
							|  |  |  | 				Op:       expr.CmpOpEq, | 
					
						
							|  |  |  | 				Register: 1, | 
					
						
							|  |  |  | 				Data:     []byte(tunname), | 
					
						
							|  |  |  | 			}, | 
					
						
							|  |  |  | 			&expr.Counter{}, | 
					
						
							|  |  |  | 			&expr.Verdict{ | 
					
						
							|  |  |  | 				Kind: expr.VerdictAccept, | 
					
						
							|  |  |  | 			}, | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // addAcceptOutgoingPacketRule adds a rule to accept outgoing packets | 
					
						
							|  |  |  | // from the given interface. | 
					
						
							|  |  |  | func addAcceptOutgoingPacketRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error { | 
					
						
							|  |  |  | 	rule := createAcceptOutgoingPacketRule(table, chain, tunname) | 
					
						
							|  |  |  | 	_ = conn.AddRule(rule) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if err := conn.Flush(); err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("flush add rule: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-05 18:12:02 -05:00
										 |  |  | // createAcceptOnPortRule creates a rule to accept incoming packets to | 
					
						
							|  |  |  | // a given destination UDP port. | 
					
						
							|  |  |  | func createAcceptOnPortRule(table *nftables.Table, chain *nftables.Chain, port uint16) *nftables.Rule { | 
					
						
							|  |  |  | 	portBytes := make([]byte, 2) | 
					
						
							|  |  |  | 	binary.BigEndian.PutUint16(portBytes, port) | 
					
						
							|  |  |  | 	return &nftables.Rule{ | 
					
						
							|  |  |  | 		Table: table, | 
					
						
							|  |  |  | 		Chain: chain, | 
					
						
							|  |  |  | 		Exprs: []expr.Any{ | 
					
						
							|  |  |  | 			&expr.Meta{ | 
					
						
							|  |  |  | 				Key:      expr.MetaKeyL4PROTO, | 
					
						
							|  |  |  | 				Register: 1, | 
					
						
							|  |  |  | 			}, | 
					
						
							|  |  |  | 			&expr.Cmp{ | 
					
						
							|  |  |  | 				Op:       expr.CmpOpEq, | 
					
						
							|  |  |  | 				Register: 1, | 
					
						
							|  |  |  | 				Data:     []byte{unix.IPPROTO_UDP}, | 
					
						
							|  |  |  | 			}, | 
					
						
							|  |  |  | 			newLoadDportExpr(1), | 
					
						
							|  |  |  | 			&expr.Cmp{ | 
					
						
							|  |  |  | 				Op:       expr.CmpOpEq, | 
					
						
							|  |  |  | 				Register: 1, | 
					
						
							|  |  |  | 				Data:     portBytes, | 
					
						
							|  |  |  | 			}, | 
					
						
							|  |  |  | 			&expr.Counter{}, | 
					
						
							|  |  |  | 			&expr.Verdict{ | 
					
						
							|  |  |  | 				Kind: expr.VerdictAccept, | 
					
						
							|  |  |  | 			}, | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // addAcceptOnPortRule adds a rule to accept incoming packets to | 
					
						
							|  |  |  | // a given destination UDP port. | 
					
						
							|  |  |  | func addAcceptOnPortRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, port uint16) error { | 
					
						
							|  |  |  | 	rule := createAcceptOnPortRule(table, chain, port) | 
					
						
							|  |  |  | 	_ = conn.AddRule(rule) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if err := conn.Flush(); err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("flush add rule: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // addAcceptOnPortRule removes a rule to accept incoming packets to | 
					
						
							|  |  |  | // a given destination UDP port. | 
					
						
							|  |  |  | func removeAcceptOnPortRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, port uint16) error { | 
					
						
							|  |  |  | 	rule := createAcceptOnPortRule(table, chain, port) | 
					
						
							|  |  |  | 	rule, err := findRule(conn, rule) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("find rule: %v", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	_ = conn.DelRule(rule) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if err := conn.Flush(); err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("flush del rule: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // AddMagicsockPortRule adds a rule to nftables to allow incoming traffic on | 
					
						
							|  |  |  | // the specified UDP port, so magicsock can accept incoming connections. | 
					
						
							|  |  |  | // network must be either "udp4" or "udp6" - this determines whether the rule | 
					
						
							|  |  |  | // is added for IPv4 or IPv6. | 
					
						
							|  |  |  | func (n *nftablesRunner) AddMagicsockPortRule(port uint16, network string) error { | 
					
						
							|  |  |  | 	var filterTable *nftables.Table | 
					
						
							|  |  |  | 	switch network { | 
					
						
							|  |  |  | 	case "udp4": | 
					
						
							|  |  |  | 		filterTable = n.nft4.Filter | 
					
						
							|  |  |  | 	case "udp6": | 
					
						
							|  |  |  | 		filterTable = n.nft6.Filter | 
					
						
							|  |  |  | 	default: | 
					
						
							|  |  |  | 		return fmt.Errorf("unsupported network %s", network) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	inputChain, err := getChainFromTable(n.conn, filterTable, chainNameInput) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("get input chain: %v", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	err = addAcceptOnPortRule(n.conn, filterTable, inputChain, port) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("add accept on port rule: %v", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // DelMagicsockPortRule removes a rule added by AddMagicsockPortRule to accept | 
					
						
							|  |  |  | // incoming traffic on a particular UDP port. | 
					
						
							|  |  |  | // network must be either "udp4" or "udp6" - this determines whether the rule | 
					
						
							|  |  |  | // is removed for IPv4 or IPv6. | 
					
						
							|  |  |  | func (n *nftablesRunner) DelMagicsockPortRule(port uint16, network string) error { | 
					
						
							|  |  |  | 	var filterTable *nftables.Table | 
					
						
							|  |  |  | 	switch network { | 
					
						
							|  |  |  | 	case "udp4": | 
					
						
							|  |  |  | 		filterTable = n.nft4.Filter | 
					
						
							|  |  |  | 	case "udp6": | 
					
						
							|  |  |  | 		filterTable = n.nft6.Filter | 
					
						
							|  |  |  | 	default: | 
					
						
							|  |  |  | 		return fmt.Errorf("unsupported network %s", network) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	inputChain, err := getChainFromTable(n.conn, filterTable, chainNameInput) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("get input chain: %v", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	err = removeAcceptOnPortRule(n.conn, filterTable, inputChain, port) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("add accept on port rule: %v", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-24 18:15:03 -07:00
										 |  |  | // createAcceptIncomingPacketRule creates a rule to accept incoming packets to | 
					
						
							|  |  |  | // the given interface. | 
					
						
							|  |  |  | func createAcceptIncomingPacketRule(table *nftables.Table, chain *nftables.Chain, tunname string) *nftables.Rule { | 
					
						
							|  |  |  | 	return &nftables.Rule{ | 
					
						
							|  |  |  | 		Table: table, | 
					
						
							|  |  |  | 		Chain: chain, | 
					
						
							|  |  |  | 		Exprs: []expr.Any{ | 
					
						
							|  |  |  | 			&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, | 
					
						
							|  |  |  | 			&expr.Cmp{ | 
					
						
							|  |  |  | 				Op:       expr.CmpOpEq, | 
					
						
							|  |  |  | 				Register: 1, | 
					
						
							|  |  |  | 				Data:     []byte(tunname), | 
					
						
							|  |  |  | 			}, | 
					
						
							|  |  |  | 			&expr.Counter{}, | 
					
						
							|  |  |  | 			&expr.Verdict{ | 
					
						
							|  |  |  | 				Kind: expr.VerdictAccept, | 
					
						
							|  |  |  | 			}, | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func addAcceptIncomingPacketRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error { | 
					
						
							|  |  |  | 	rule := createAcceptIncomingPacketRule(table, chain, tunname) | 
					
						
							|  |  |  | 	_ = conn.AddRule(rule) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if err := conn.Flush(); err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("flush add rule: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | // AddBase adds some basic processing rules. | 
					
						
							|  |  |  | func (n *nftablesRunner) AddBase(tunname string) error { | 
					
						
							|  |  |  | 	if err := n.addBase4(tunname); err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("add base v4: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	if n.HasIPV6() { | 
					
						
							|  |  |  | 		if err := n.addBase6(tunname); err != nil { | 
					
						
							|  |  |  | 			return fmt.Errorf("add base v6: %w", err) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // addBase4 adds some basic IPv4 processing rules. | 
					
						
							|  |  |  | func (n *nftablesRunner) addBase4(tunname string) error { | 
					
						
							|  |  |  | 	conn := n.conn | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	inputChain, err := getChainFromTable(conn, n.nft4.Filter, chainNameInput) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("get input chain v4: %v", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	if err = addReturnChromeOSVMRangeRule(conn, n.nft4.Filter, inputChain, tunname); err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("add return chromeos vm range rule v4: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	if err = addDropCGNATRangeRule(conn, n.nft4.Filter, inputChain, tunname); err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("add drop cgnat range rule v4: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2023-08-24 18:15:03 -07:00
										 |  |  | 	if err = addAcceptIncomingPacketRule(conn, n.nft4.Filter, inputChain, tunname); err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("add accept incoming packet rule v4: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | 	forwardChain, err := getChainFromTable(conn, n.nft4.Filter, chainNameForward) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("get forward chain v4: %v", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if err = addSetSubnetRouteMarkRule(conn, n.nft4.Filter, forwardChain, tunname); err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("add set subnet route mark rule v4: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if err = addMatchSubnetRouteMarkRule(conn, n.nft4.Filter, forwardChain, Accept); err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("add match subnet route mark rule v4: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if err = addDropOutgoingPacketFromCGNATRangeRuleWithTunname(conn, n.nft4.Filter, forwardChain, tunname); err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("add drop outgoing packet from cgnat range rule v4: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if err = addAcceptOutgoingPacketRule(conn, n.nft4.Filter, forwardChain, tunname); err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("add accept outgoing packet rule v4: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if err = conn.Flush(); err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("flush base v4: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // addBase6 adds some basic IPv6 processing rules. | 
					
						
							|  |  |  | func (n *nftablesRunner) addBase6(tunname string) error { | 
					
						
							|  |  |  | 	conn := n.conn | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-24 18:15:03 -07:00
										 |  |  | 	inputChain, err := getChainFromTable(conn, n.nft6.Filter, chainNameInput) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("get input chain v4: %v", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	if err = addAcceptIncomingPacketRule(conn, n.nft6.Filter, inputChain, tunname); err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("add accept incoming packet rule v6: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | 	forwardChain, err := getChainFromTable(conn, n.nft6.Filter, chainNameForward) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("get forward chain v6: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if err = addSetSubnetRouteMarkRule(conn, n.nft6.Filter, forwardChain, tunname); err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("add set subnet route mark rule v6: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if err = addMatchSubnetRouteMarkRule(conn, n.nft6.Filter, forwardChain, Accept); err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("add match subnet route mark rule v6: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if err = addAcceptOutgoingPacketRule(conn, n.nft6.Filter, forwardChain, tunname); err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("add accept outgoing packet rule v6: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if err = conn.Flush(); err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("flush base v6: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // DelBase empties, but does not remove, custom Tailscale chains from | 
					
						
							|  |  |  | // netfilter via iptables. | 
					
						
							|  |  |  | func (n *nftablesRunner) DelBase() error { | 
					
						
							|  |  |  | 	conn := n.conn | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	for _, table := range n.getTables() { | 
					
						
							|  |  |  | 		inputChain, err := getChainFromTable(conn, table.Filter, chainNameInput) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			return fmt.Errorf("get input chain: %v", err) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		conn.FlushChain(inputChain) | 
					
						
							|  |  |  | 		forwardChain, err := getChainFromTable(conn, table.Filter, chainNameForward) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			return fmt.Errorf("get forward chain: %v", err) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		conn.FlushChain(forwardChain) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		postrouteChain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			return fmt.Errorf("get postrouting chain v4: %v", err) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		conn.FlushChain(postrouteChain) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return conn.Flush() | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // createMatchSubnetRouteMarkRule creates a rule that matches packets | 
					
						
							|  |  |  | // with the subnet route mark and takes the specified action. | 
					
						
							|  |  |  | func createMatchSubnetRouteMarkRule(table *nftables.Table, chain *nftables.Chain, action MatchDecision) (*nftables.Rule, error) { | 
					
						
							|  |  |  | 	hexTSFwmarkMask := getTailscaleFwmarkMask() | 
					
						
							|  |  |  | 	hexTSSubnetRouteMark := getTailscaleSubnetRouteMark() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	var endAction expr.Any | 
					
						
							|  |  |  | 	endAction = &expr.Verdict{Kind: expr.VerdictAccept} | 
					
						
							|  |  |  | 	if action == Masq { | 
					
						
							|  |  |  | 		endAction = &expr.Masq{} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	exprs := []expr.Any{ | 
					
						
							|  |  |  | 		&expr.Meta{Key: expr.MetaKeyMARK, Register: 1}, | 
					
						
							|  |  |  | 		&expr.Bitwise{ | 
					
						
							|  |  |  | 			SourceRegister: 1, | 
					
						
							|  |  |  | 			DestRegister:   1, | 
					
						
							|  |  |  | 			Len:            4, | 
					
						
							|  |  |  | 			Mask:           hexTSFwmarkMask, | 
					
						
							|  |  |  | 			Xor:            []byte{0x00, 0x00, 0x00, 0x00}, | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 		&expr.Cmp{ | 
					
						
							|  |  |  | 			Op:       expr.CmpOpEq, | 
					
						
							|  |  |  | 			Register: 1, | 
					
						
							|  |  |  | 			Data:     hexTSSubnetRouteMark, | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 		&expr.Counter{}, | 
					
						
							|  |  |  | 		endAction, | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	rule := &nftables.Rule{ | 
					
						
							|  |  |  | 		Table: table, | 
					
						
							|  |  |  | 		Chain: chain, | 
					
						
							|  |  |  | 		Exprs: exprs, | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return rule, nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // addMatchSubnetRouteMarkRule adds a rule that matches packets with | 
					
						
							|  |  |  | // the subnet route mark and takes the specified action. | 
					
						
							|  |  |  | func addMatchSubnetRouteMarkRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, action MatchDecision) error { | 
					
						
							|  |  |  | 	rule, err := createMatchSubnetRouteMarkRule(table, chain, action) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("create match subnet route mark rule: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	_ = conn.AddRule(rule) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if err := conn.Flush(); err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("flush add rule: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // AddSNATRule adds a netfilter rule to SNAT traffic destined for | 
					
						
							|  |  |  | // local subnets. | 
					
						
							|  |  |  | func (n *nftablesRunner) AddSNATRule() error { | 
					
						
							|  |  |  | 	conn := n.conn | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-05-14 08:51:53 +01:00
										 |  |  | 	for _, table := range n.getTables() { | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | 		chain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			return fmt.Errorf("get postrouting chain v4: %w", err) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		if err = addMatchSubnetRouteMarkRule(conn, table.Nat, chain, Masq); err != nil { | 
					
						
							|  |  |  | 			return fmt.Errorf("add match subnet route mark rule v4: %w", err) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if err := conn.Flush(); err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("flush add SNAT rule: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // DelSNATRule removes the netfilter rule to SNAT traffic destined for | 
					
						
							|  |  |  | // local subnets. An error is returned if the rule does not exist. | 
					
						
							|  |  |  | func (n *nftablesRunner) DelSNATRule() error { | 
					
						
							|  |  |  | 	conn := n.conn | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	hexTSFwmarkMask := getTailscaleFwmarkMask() | 
					
						
							|  |  |  | 	hexTSSubnetRouteMark := getTailscaleSubnetRouteMark() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	exprs := []expr.Any{ | 
					
						
							|  |  |  | 		&expr.Meta{Key: expr.MetaKeyMARK, Register: 1}, | 
					
						
							|  |  |  | 		&expr.Bitwise{ | 
					
						
							|  |  |  | 			SourceRegister: 1, | 
					
						
							|  |  |  | 			DestRegister:   1, | 
					
						
							|  |  |  | 			Len:            4, | 
					
						
							|  |  |  | 			Mask:           hexTSFwmarkMask, | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 		&expr.Cmp{ | 
					
						
							|  |  |  | 			Op:       expr.CmpOpEq, | 
					
						
							|  |  |  | 			Register: 1, | 
					
						
							|  |  |  | 			Data:     hexTSSubnetRouteMark, | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 		&expr.Counter{}, | 
					
						
							|  |  |  | 		&expr.Masq{}, | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-05-14 08:51:53 +01:00
										 |  |  | 	for _, table := range n.getTables() { | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | 		chain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			return fmt.Errorf("get postrouting chain v4: %w", err) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		rule := &nftables.Rule{ | 
					
						
							|  |  |  | 			Table: table.Nat, | 
					
						
							|  |  |  | 			Chain: chain, | 
					
						
							|  |  |  | 			Exprs: exprs, | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		SNATRule, err := findRule(conn, rule) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			return fmt.Errorf("find SNAT rule v4: %w", err) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-09-27 10:56:11 -07:00
										 |  |  | 		if SNATRule != nil { | 
					
						
							|  |  |  | 			_ = conn.DelRule(SNATRule) | 
					
						
							|  |  |  | 		} | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if err := conn.Flush(); err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("flush del SNAT rule: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-05-06 15:22:17 -07:00
										 |  |  | func nativeUint32(v uint32) []byte { | 
					
						
							|  |  |  | 	b := make([]byte, 4) | 
					
						
							|  |  |  | 	binary.NativeEndian.PutUint32(b, v) | 
					
						
							|  |  |  | 	return b | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func makeStatefulRuleExprs(tunname string) []expr.Any { | 
					
						
							|  |  |  | 	return []expr.Any{ | 
					
						
							|  |  |  | 		// Check if the output interface is the Tailscale interface by | 
					
						
							|  |  |  | 		// first loding the OIFNAME into register 1 and comparing it | 
					
						
							|  |  |  | 		// against our tunname. | 
					
						
							|  |  |  | 		// | 
					
						
							|  |  |  | 		// 'cmp' implicitly breaks from a rule if a comparison fails, | 
					
						
							|  |  |  | 		// so if we continue past this rule we know that the packet is | 
					
						
							|  |  |  | 		// going to our TUN. | 
					
						
							|  |  |  | 		&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, | 
					
						
							|  |  |  | 		&expr.Cmp{ | 
					
						
							| 
									
										
										
										
											2024-05-09 10:26:57 +01:00
										 |  |  | 			Op:       expr.CmpOpEq, | 
					
						
							| 
									
										
										
										
											2024-05-06 15:22:17 -07:00
										 |  |  | 			Register: 1, | 
					
						
							|  |  |  | 			Data:     []byte(tunname), | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		// Store the conntrack state in register 1 | 
					
						
							|  |  |  | 		&expr.Ct{ | 
					
						
							|  |  |  | 			Register: 1, | 
					
						
							|  |  |  | 			Key:      expr.CtKeySTATE, | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 		// Mask the state in register 1 to "hide" the ESTABLISHED and | 
					
						
							|  |  |  | 		// RELATED bits (which are expected and fine); if there are any | 
					
						
							|  |  |  | 		// other bits, we want them to remain. | 
					
						
							|  |  |  | 		// | 
					
						
							|  |  |  | 		// This operation is, in the kernel: | 
					
						
							|  |  |  | 		//    dst[i] = (src[i] & mask[i]) ^ xor[i] | 
					
						
							|  |  |  | 		// | 
					
						
							|  |  |  | 		// So, we can mask by setting the inverse of the bits we want | 
					
						
							|  |  |  | 		// to remove; i.e. ESTABLISHED = 0b00000010, RELATED = | 
					
						
							|  |  |  | 		// 0b00000100, so, if we assume an 8-bit state (in reality, | 
					
						
							|  |  |  | 		// it's 32-bit), we can mask with 0b11111001 to clear those | 
					
						
							|  |  |  | 		// bits and keep everything else (e.g. the INVALID bit which is | 
					
						
							|  |  |  | 		// 0b00000001). | 
					
						
							|  |  |  | 		// | 
					
						
							|  |  |  | 		// TODO(andrew-d): for now, let's also allow | 
					
						
							|  |  |  | 		// CtStateBitUNTRACKED, which is a state for packets that are not | 
					
						
							|  |  |  | 		// tracked (marked so explicitly with an iptables rule using | 
					
						
							|  |  |  | 		// --notrack); we should figure out if we want to allow this or not. | 
					
						
							|  |  |  | 		&expr.Bitwise{ | 
					
						
							|  |  |  | 			SourceRegister: 1, | 
					
						
							|  |  |  | 			DestRegister:   1, | 
					
						
							|  |  |  | 			Len:            4, | 
					
						
							|  |  |  | 			Mask: nativeUint32(^(0 | | 
					
						
							|  |  |  | 				expr.CtStateBitESTABLISHED | | 
					
						
							|  |  |  | 				expr.CtStateBitRELATED | | 
					
						
							|  |  |  | 				expr.CtStateBitUNTRACKED)), | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			// Xor is unused but must be specified | 
					
						
							|  |  |  | 			Xor: nativeUint32(0), | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 		// Compare against the expected state (0, i.e. no bits set | 
					
						
							|  |  |  | 		// other than maybe ESTABLISHED and RELATED). We want this | 
					
						
							|  |  |  | 		// comparison to fail if there are no bits set, so that this | 
					
						
							|  |  |  | 		// rule's evaluation stops and we don't fall through to the | 
					
						
							|  |  |  | 		// "Drop" verdict. | 
					
						
							|  |  |  | 		// | 
					
						
							|  |  |  | 		// For example, if the state is ESTABLISHED (and we want to | 
					
						
							|  |  |  | 		// break from this rule/accept this packet): | 
					
						
							|  |  |  | 		//   state     = ESTABLISHED | 
					
						
							|  |  |  | 		//   register1 = 0b0 (since the bitwise operation cleared the ESTABLISHED bit) | 
					
						
							|  |  |  | 		// | 
					
						
							|  |  |  | 		//   compare register1 (0b0) != 0: false | 
					
						
							|  |  |  | 		//   -> comparison implicitly breaks | 
					
						
							|  |  |  | 		//   -> continue to the next rule | 
					
						
							|  |  |  | 		// | 
					
						
							|  |  |  | 		// For example, if the state is NEW (and we want to continue to | 
					
						
							|  |  |  | 		// the next expression and thus drop this packet): | 
					
						
							|  |  |  | 		//   state     = NEW | 
					
						
							|  |  |  | 		//   register1 = 0b1000 | 
					
						
							|  |  |  | 		// | 
					
						
							|  |  |  | 		//   compare register1 (0b1000) != 0: true | 
					
						
							|  |  |  | 		//   -> comparison continues to next expr | 
					
						
							|  |  |  | 		&expr.Cmp{ | 
					
						
							|  |  |  | 			Op:       expr.CmpOpNeq, | 
					
						
							|  |  |  | 			Register: 1, | 
					
						
							|  |  |  | 			Data:     []byte{0, 0, 0, 0}, | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 		// If we get here, we know that this packet is going to our TUN | 
					
						
							|  |  |  | 		// device, and has a conntrack state set other than ESTABLISHED | 
					
						
							|  |  |  | 		// or RELATED. We thus count and drop the packet. | 
					
						
							|  |  |  | 		&expr.Counter{}, | 
					
						
							|  |  |  | 		&expr.Verdict{Kind: expr.VerdictDrop}, | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// TODO(andrew-d): iptables-nft writes a rule that dumps as: | 
					
						
							|  |  |  | 	// | 
					
						
							|  |  |  | 	//	match name conntrack rev 3 | 
					
						
							|  |  |  | 	// | 
					
						
							|  |  |  | 	// I think this is using expr.Match against the following struct | 
					
						
							|  |  |  | 	// (xt_conntrack_mtinfo3): | 
					
						
							|  |  |  | 	// | 
					
						
							|  |  |  | 	//	https://github.com/torvalds/linux/blob/master/include/uapi/linux/netfilter/xt_conntrack.h#L64-L77 | 
					
						
							|  |  |  | 	// | 
					
						
							|  |  |  | 	// We could probably do something similar here, but I'm not sure if | 
					
						
							|  |  |  | 	// there's any advantage. Below is an example Match statement if we | 
					
						
							|  |  |  | 	// decide to do that, based on dumping the rule that iptables-nft | 
					
						
							|  |  |  | 	// generates: | 
					
						
							|  |  |  | 	// | 
					
						
							|  |  |  | 	//	_ = expr.Match{ | 
					
						
							|  |  |  | 	//		Name: "conntrack", | 
					
						
							|  |  |  | 	//		Rev:  3, | 
					
						
							|  |  |  | 	//		Info: &xt.ConntrackMtinfo3{ | 
					
						
							|  |  |  | 	//			ConntrackMtinfo2: xt.ConntrackMtinfo2{ | 
					
						
							|  |  |  | 	//				ConntrackMtinfoBase: xt.ConntrackMtinfoBase{ | 
					
						
							|  |  |  | 	//					MatchFlags:  xt.ConntrackState, | 
					
						
							|  |  |  | 	//					InvertFlags: xt.ConntrackState, | 
					
						
							|  |  |  | 	//				}, | 
					
						
							|  |  |  | 	//				// Mask the state to remove ESTABLISHED and | 
					
						
							|  |  |  | 	//				// RELATED before comparing. | 
					
						
							|  |  |  | 	//				StateMask: expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED, | 
					
						
							|  |  |  | 	//			}, | 
					
						
							|  |  |  | 	//		}, | 
					
						
							|  |  |  | 	//	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // AddStatefulRule adds a netfilter rule for stateful packet filtering using | 
					
						
							|  |  |  | // conntrack. | 
					
						
							|  |  |  | func (n *nftablesRunner) AddStatefulRule(tunname string) error { | 
					
						
							|  |  |  | 	conn := n.conn | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	exprs := makeStatefulRuleExprs(tunname) | 
					
						
							|  |  |  | 	for _, table := range n.getTables() { | 
					
						
							|  |  |  | 		chain, err := getChainFromTable(conn, table.Filter, chainNameForward) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			return fmt.Errorf("get forward chain: %w", err) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		// First, find the 'accept' rule that we want to insert our rule before. | 
					
						
							|  |  |  | 		acceptRule := createAcceptOutgoingPacketRule(table.Filter, chain, tunname) | 
					
						
							|  |  |  | 		rule, err := findRule(conn, acceptRule) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			return fmt.Errorf("find accept rule: %w", err) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		conn.InsertRule(&nftables.Rule{ | 
					
						
							|  |  |  | 			Table: table.Filter, | 
					
						
							|  |  |  | 			Chain: chain, | 
					
						
							|  |  |  | 			Exprs: exprs, | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			// Specifying Position in an Insert operation means to | 
					
						
							|  |  |  | 			// insert this rule before the specified rule. | 
					
						
							|  |  |  | 			Position: rule.Handle, | 
					
						
							|  |  |  | 		}) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if err := conn.Flush(); err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("flush add stateful rule: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // DelStatefulRule removes the netfilter rule for stateful packet filtering | 
					
						
							|  |  |  | // using conntrack. | 
					
						
							|  |  |  | func (n *nftablesRunner) DelStatefulRule(tunname string) error { | 
					
						
							|  |  |  | 	conn := n.conn | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	exprs := makeStatefulRuleExprs(tunname) | 
					
						
							|  |  |  | 	for _, table := range n.getTables() { | 
					
						
							|  |  |  | 		chain, err := getChainFromTable(conn, table.Filter, chainNameForward) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			return fmt.Errorf("get forward chain: %w", err) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		rule, err := findRule(conn, &nftables.Rule{ | 
					
						
							| 
									
										
										
										
											2024-05-09 13:32:39 -04:00
										 |  |  | 			Table: table.Filter, | 
					
						
							| 
									
										
										
										
											2024-05-06 15:22:17 -07:00
										 |  |  | 			Chain: chain, | 
					
						
							|  |  |  | 			Exprs: exprs, | 
					
						
							|  |  |  | 		}) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			return fmt.Errorf("find stateful rule: %w", err) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		if rule != nil { | 
					
						
							|  |  |  | 			conn.DelRule(rule) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	if err := conn.Flush(); err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("flush del stateful rule: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-18 18:41:28 +00:00
										 |  |  | // cleanupChain removes a jump rule from hookChainName to tsChainName, and then | 
					
						
							|  |  |  | // the entire chain tsChainName. Errors are logged, but attempts to remove both | 
					
						
							|  |  |  | // the jump rule and chain continue even if one errors. | 
					
						
							|  |  |  | func cleanupChain(logf logger.Logf, conn *nftables.Conn, table *nftables.Table, hookChainName, tsChainName string) { | 
					
						
							|  |  |  | 	// remove the jump first, before removing the jump destination. | 
					
						
							|  |  |  | 	defaultChain, err := getChainFromTable(conn, table, hookChainName) | 
					
						
							|  |  |  | 	if err != nil && !errors.Is(err, errorChainNotFound{table.Name, hookChainName}) { | 
					
						
							|  |  |  | 		logf("cleanup: did not find default chain: %s", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	if !errors.Is(err, errorChainNotFound{table.Name, hookChainName}) { | 
					
						
							|  |  |  | 		// delete hook in convention chain | 
					
						
							|  |  |  | 		_ = delHookRule(conn, table, defaultChain, tsChainName) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	tsChain, err := getChainFromTable(conn, table, tsChainName) | 
					
						
							|  |  |  | 	if err != nil && !errors.Is(err, errorChainNotFound{table.Name, tsChainName}) { | 
					
						
							|  |  |  | 		logf("cleanup: did not find ts-chain: %s", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if tsChain != nil { | 
					
						
							|  |  |  | 		// flush and delete ts-chain | 
					
						
							|  |  |  | 		conn.FlushChain(tsChain) | 
					
						
							|  |  |  | 		conn.DelChain(tsChain) | 
					
						
							|  |  |  | 		err = conn.Flush() | 
					
						
							|  |  |  | 		logf("cleanup: delete and flush chain %s: %s", tsChainName, err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | // NfTablesCleanUp removes all Tailscale added nftables rules. | 
					
						
							|  |  |  | // Any errors that occur are logged to the provided logf. | 
					
						
							|  |  |  | func NfTablesCleanUp(logf logger.Logf) { | 
					
						
							|  |  |  | 	conn, err := nftables.New() | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							| 
									
										
										
										
											2023-08-18 18:41:28 +00:00
										 |  |  | 		logf("cleanup: nftables connection: %s", err) | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	tables, err := conn.ListTables() // both v4 and v6 | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							| 
									
										
										
										
											2023-08-18 18:41:28 +00:00
										 |  |  | 		logf("cleanup: list tables: %s", err) | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	for _, table := range tables { | 
					
						
							| 
									
										
										
										
											2023-08-18 18:41:28 +00:00
										 |  |  | 		// These table names were used briefly in 1.48.0. | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | 		if table.Name == "ts-filter" || table.Name == "ts-nat" { | 
					
						
							|  |  |  | 			conn.DelTable(table) | 
					
						
							|  |  |  | 			if err := conn.Flush(); err != nil { | 
					
						
							| 
									
										
										
										
											2023-08-18 18:41:28 +00:00
										 |  |  | 				logf("cleanup: flush delete table %s: %s", table.Name, err) | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | 			} | 
					
						
							|  |  |  | 		} | 
					
						
							| 
									
										
										
										
											2023-08-18 18:41:28 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | 		if table.Name == "filter" { | 
					
						
							|  |  |  | 			cleanupChain(logf, conn, table, "INPUT", chainNameInput) | 
					
						
							|  |  |  | 			cleanupChain(logf, conn, table, "FORWARD", chainNameForward) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		if table.Name == "nat" { | 
					
						
							|  |  |  | 			cleanupChain(logf, conn, table, "POSTROUTING", chainNamePostrouting) | 
					
						
							|  |  |  | 		} | 
					
						
							| 
									
										
										
										
											2023-06-19 20:05:14 +00:00
										 |  |  | 	} | 
					
						
							|  |  |  | } |