wgengine/router: use netlink for ip rules on Linux

Using temporary netlink fork in github.com/tailscale/netlink until we
get the necessary changes upstream in either vishvananda/netlink
or jsimonetti/rtnetlink.

Updates #391

Change-Id: I6e1de96cf0750ccba53dabff670aca0c56dffb7c
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick
2021-10-28 15:22:03 -07:00
committed by Brad Fitzpatrick
parent 5dc5bd8d20
commit ff1954cfd9
5 changed files with 162 additions and 36 deletions

View File

@@ -18,7 +18,7 @@ import (
"github.com/coreos/go-iptables/iptables"
"github.com/go-multierror/multierror"
"github.com/vishvananda/netlink"
"github.com/tailscale/netlink"
"golang.org/x/sys/unix"
"golang.org/x/time/rate"
"golang.zx2c4.com/wireguard/tun"
@@ -604,7 +604,11 @@ func (r *linuxRouter) addRouteDef(routeDef []string, cidr netaddr.IPPrefix) erro
return err
}
var errESRCH error = syscall.ESRCH
var (
errESRCH error = syscall.ESRCH
errENOENT error = syscall.ENOENT
errEEXIST error = syscall.EEXIST
)
// delRoute removes the route for cidr pointing to the tunnel
// interface. Fails if the route doesn't exist, or if removing the
@@ -766,6 +770,16 @@ func (f addrFamily) dashArg() string {
panic("illegal")
}
func (f addrFamily) netlinkInt() int {
switch f {
case 4:
return netlink.FAMILY_V4
case 6:
return netlink.FAMILY_V6
}
panic("illegal")
}
func (r *linuxRouter) addrFamilies() []addrFamily {
if r.v6Available {
return []addrFamily{v4, v6}
@@ -878,7 +892,7 @@ var ipRules = []netlink.Rule{
{
Priority: 5250,
Mark: tailscaleBypassMarkNum,
Table: 0, // unreachable
Type: unix.RTN_UNREACHABLE,
},
// If we get to this point, capture all packets and send them
// through to the tailscale route table. For apps other than us
@@ -898,7 +912,34 @@ func (r *linuxRouter) justAddIPRules() error {
if !r.ipRuleAvailable {
return nil
}
if r.useIPCommand() {
return r.addIPRulesWithIPCommand()
}
var errAcc error
for _, family := range r.addrFamilies() {
for _, ru := range ipRules {
// Note: r is a value type here; safe to mutate it.
ru.Family = family.netlinkInt()
ru.Mask = -1
ru.Goto = -1
ru.SuppressIfgroup = -1
ru.SuppressPrefixlen = -1
ru.Flow = -1
err := netlink.RuleAdd(&ru)
if errors.Is(err, errEEXIST) {
// Ignore dups.
continue
}
if err != nil && errAcc == nil {
errAcc = err
}
}
}
return errAcc
}
func (r *linuxRouter) addIPRulesWithIPCommand() error {
rg := newRunGroup(nil, r.cmd)
for _, family := range r.addrFamilies() {
@@ -913,7 +954,8 @@ func (r *linuxRouter) justAddIPRules() error {
}
if r.Table != 0 {
args = append(args, "table", mustRouteTable(r.Table).ipCmdArg())
} else {
}
if r.Type == unix.RTN_UNREACHABLE {
args = append(args, "type", "unreachable")
}
rg.Run(args...)
@@ -940,7 +982,39 @@ func (r *linuxRouter) delIPRules() error {
if !r.ipRuleAvailable {
return nil
}
if r.useIPCommand() {
return r.delIPRulesWithIPCommand()
}
var errAcc error
for _, family := range r.addrFamilies() {
for _, ru := range ipRules {
// Note: r is a value type here; safe to mutate it.
// When deleting rules, we want to be a bit specific (mention which
// table we were routing to) but not *too* specific (fwmarks, etc).
// That leaves us some flexibility to change these values in later
// versions without having ongoing hacks for every possible
// combination.
ru.Family = family.netlinkInt()
ru.Mark = -1
ru.Mask = -1
ru.Goto = -1
ru.SuppressIfgroup = -1
ru.SuppressPrefixlen = -1
err := netlink.RuleDel(&ru)
if errors.Is(err, errENOENT) {
// Didn't exist to begin with.
continue
}
if err != nil && errAcc == nil {
errAcc = err
}
}
}
return errAcc
}
func (r *linuxRouter) delIPRulesWithIPCommand() error {
// Error codes: 'ip rule' returns error code 2 if the rule is a
// duplicate (add) or not found (del). It returns a different code
// for syntax errors. This is also true of busybox.

View File

@@ -654,62 +654,110 @@ func createTestTUN(t *testing.T) tun.Device {
return tun
}
func TestDelRouteIdempotent(t *testing.T) {
type linuxTest struct {
tun tun.Device
mon *monitor.Mon
r *linuxRouter
logOutput tstest.MemLogger
}
func (lt *linuxTest) Close() error {
if lt.tun != nil {
lt.tun.Close()
}
if lt.mon != nil {
lt.mon.Close()
}
return nil
}
func newLinuxRootTest(t *testing.T) *linuxTest {
if os.Getuid() != 0 {
t.Skip("test requires root")
}
tun := createTestTUN(t)
defer tun.Close()
var logOutput tstest.MemLogger
logf := logOutput.Logf
lt := new(linuxTest)
lt.tun = createTestTUN(t)
logf := lt.logOutput.Logf
mon, err := monitor.New(logger.Discard)
if err != nil {
lt.Close()
t.Fatal(err)
}
mon.Start()
defer mon.Close()
lt.mon = mon
r, err := newUserspaceRouter(logf, tun, mon)
r, err := newUserspaceRouter(logf, lt.tun, mon)
if err != nil {
lt.Close()
t.Fatal(err)
}
lr := r.(*linuxRouter)
if err := lr.upInterface(); err != nil {
lt.Close()
t.Fatal(err)
}
lt.r = lr
return lt
}
func TestDelRouteIdempotent(t *testing.T) {
lt := newLinuxRootTest(t)
defer lt.Close()
for _, s := range []string{
"192.0.2.0/24", // RFC 5737
"2001:DB8::/32", // RFC 3849
} {
cidr := netaddr.MustParseIPPrefix(s)
if err := lr.addRoute(cidr); err != nil {
t.Fatal(err)
if err := lt.r.addRoute(cidr); err != nil {
t.Error(err)
continue
}
for i := 0; i < 2; i++ {
if err := lr.delRoute(cidr); err != nil {
t.Fatalf("delRoute(i=%d): %v", i, err)
if err := lt.r.delRoute(cidr); err != nil {
t.Errorf("delRoute(i=%d): %v", i, err)
}
}
}
wantSubs := map[string]int{
"warning: tried to delete route 192.0.2.0/24 but it was already gone; ignoring error": 1,
"warning: tried to delete route 2001:db8::/32 but it was already gone; ignoring error": 1,
}
out := logOutput.String()
for sub, want := range wantSubs {
got := strings.Count(out, sub)
if got != want {
t.Errorf("log output substring %q occurred %d time; want %d", sub, got, want)
}
}
if t.Failed() {
out := lt.logOutput.String()
t.Logf("Log output:\n%s", out)
}
}
func TestAddRemoveRules(t *testing.T) {
lt := newLinuxRootTest(t)
defer lt.Close()
r := lt.r
step := func(name string, f func() error) {
t.Logf("Doing %v ...", name)
if err := f(); err != nil {
t.Fatalf("%s: %v", name, err)
}
rules, err := netlink.RuleList(netlink.FAMILY_ALL)
if err != nil {
t.Fatal(err)
}
for _, r := range rules {
if r.Priority >= 5000 && r.Priority <= 5999 {
t.Logf("Rule: %+v", r)
}
}
}
step("init_del_and_add", r.addIPRules)
step("dup_add", r.justAddIPRules)
step("del", r.delIPRules)
step("dup_del", r.delIPRules)
}
func TestDebugListLinks(t *testing.T) {
links, err := netlink.LinkList()
if err != nil {