From ee94447b4f259c23be023a2cd1b9a2d4fbc0fcba Mon Sep 17 00:00:00 2001 From: Nick Khyl Date: Fri, 14 Mar 2025 15:17:26 -0500 Subject: [PATCH] ipn/ipnlocal: unconfigure wgengine when switching profiles LocalBackend transitions to ipn.NoState when switching to a different (or new) profile. When this happens, we should unconfigure wgengine to clear routes, DNS configuration, firewall rules that block all traffic except to the exit node, etc. In this PR, we update (*LocalBackend).enterStateLockedOnEntry to do just that. Fixes #15316 Updates tailscale/corp#23967 Signed-off-by: Nick Khyl --- ipn/ipnlocal/local.go | 8 +- ipn/ipnlocal/local_test.go | 24 +- ipn/ipnlocal/state_test.go | 474 +++++++++++++++++++++++++++++++++++++ 3 files changed, 494 insertions(+), 12 deletions(-) diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index f866527d1..5fd1127d8 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -5706,13 +5706,15 @@ func (b *LocalBackend) enterStateLockedOnEntry(newState ipn.State, unlock unlock } b.blockEngineUpdates(true) fallthrough - case ipn.Stopped: + case ipn.Stopped, ipn.NoState: + // Unconfigure the engine if it has stopped (WantRunning is set to false) + // or if we've switched to a different profile and the state is unknown. err := b.e.Reconfig(&wgcfg.Config{}, &router.Config{}, &dns.Config{}) if err != nil { b.logf("Reconfig(down): %v", err) } - if authURL == "" { + if newState == ipn.Stopped && authURL == "" { systemd.Status("Stopped; run 'tailscale up' to log in") } case ipn.Starting, ipn.NeedsMachineAuth: @@ -5726,8 +5728,6 @@ func (b *LocalBackend) enterStateLockedOnEntry(newState ipn.State, unlock unlock addrStrs = append(addrStrs, p.Addr().String()) } systemd.Status("Connected; %s; %s", activeLogin, strings.Join(addrStrs, " ")) - case ipn.NoState: - // Do nothing. default: b.logf("[unexpected] unknown newState %#v", newState) } diff --git a/ipn/ipnlocal/local_test.go b/ipn/ipnlocal/local_test.go index 35977e679..9529bbe3b 100644 --- a/ipn/ipnlocal/local_test.go +++ b/ipn/ipnlocal/local_test.go @@ -4396,19 +4396,27 @@ func TestNotificationTargetMatch(t *testing.T) { type newTestControlFn func(tb testing.TB, opts controlclient.Options) controlclient.Client func newLocalBackendWithTestControl(t *testing.T, enableLogging bool, newControl newTestControlFn) *LocalBackend { + return newLocalBackendWithSysAndTestControl(t, enableLogging, new(tsd.System), newControl) +} + +func newLocalBackendWithSysAndTestControl(t *testing.T, enableLogging bool, sys *tsd.System, newControl newTestControlFn) *LocalBackend { logf := logger.Discard if enableLogging { logf = tstest.WhileTestRunningLogger(t) } - sys := new(tsd.System) - store := new(mem.Store) - sys.Set(store) - e, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.HealthTracker(), sys.UserMetricsRegistry()) - if err != nil { - t.Fatalf("NewFakeUserspaceEngine: %v", err) + + if _, hasStore := sys.StateStore.GetOK(); !hasStore { + store := new(mem.Store) + sys.Set(store) + } + if _, hasEngine := sys.Engine.GetOK(); !hasEngine { + e, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.HealthTracker(), sys.UserMetricsRegistry()) + if err != nil { + t.Fatalf("NewFakeUserspaceEngine: %v", err) + } + t.Cleanup(e.Close) + sys.Set(e) } - t.Cleanup(e.Close) - sys.Set(e) b, err := NewLocalBackend(logf, logid.PublicID{}, sys, 0) if err != nil { diff --git a/ipn/ipnlocal/state_test.go b/ipn/ipnlocal/state_test.go index 1b3b43af6..7a767da31 100644 --- a/ipn/ipnlocal/state_test.go +++ b/ipn/ipnlocal/state_test.go @@ -5,26 +5,46 @@ package ipnlocal import ( "context" + "errors" + "net/netip" + "strings" "sync" "sync/atomic" "testing" "time" qt "github.com/frankban/quicktest" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "tailscale.com/control/controlclient" "tailscale.com/envknob" "tailscale.com/ipn" + "tailscale.com/ipn/ipnstate" "tailscale.com/ipn/store/mem" + "tailscale.com/net/dns" + "tailscale.com/net/netmon" + "tailscale.com/net/packet" + "tailscale.com/net/tsdial" "tailscale.com/tailcfg" "tailscale.com/tsd" "tailscale.com/tstest" + "tailscale.com/types/dnstype" "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/logid" "tailscale.com/types/netmap" "tailscale.com/types/persist" + "tailscale.com/types/preftype" + "tailscale.com/util/dnsname" + "tailscale.com/util/mak" + "tailscale.com/util/must" "tailscale.com/wgengine" + "tailscale.com/wgengine/filter" + "tailscale.com/wgengine/magicsock" + "tailscale.com/wgengine/router" + "tailscale.com/wgengine/wgcfg" + "tailscale.com/wgengine/wgint" ) // notifyThrottler receives notifications from an ipn.Backend, blocking @@ -170,6 +190,14 @@ func (cc *mockControl) send(err error, url string, loginFinished bool, nm *netma } } +func (cc *mockControl) authenticated(nm *netmap.NetworkMap) { + if selfUser, ok := nm.UserProfiles[nm.SelfNode.User()]; ok { + cc.persist.UserProfile = *selfUser.AsStruct() + } + cc.persist.NodeID = nm.SelfNode.StableID() + cc.send(nil, "", true, nm) +} + // called records that a particular function name was called. func (cc *mockControl) called(s string) { cc.mu.Lock() @@ -1072,3 +1100,449 @@ func TestWGEngineStatusRace(t *testing.T) { wg.Wait() wantState(ipn.Running) } + +// TestEngineReconfigOnStateChange verifies that wgengine is properly reconfigured +// when the LocalBackend's state changes, such as when the user logs in, switches +// profiles, or disconnects from Tailscale. +func TestEngineReconfigOnStateChange(t *testing.T) { + enableLogging := false + connect := &ipn.MaskedPrefs{Prefs: ipn.Prefs{WantRunning: true}, WantRunningSet: true} + disconnect := &ipn.MaskedPrefs{Prefs: ipn.Prefs{WantRunning: false}, WantRunningSet: true} + node1 := testNetmapForNode(1, "node-1", []netip.Prefix{netip.MustParsePrefix("100.64.1.1/32")}) + node2 := testNetmapForNode(2, "node-2", []netip.Prefix{netip.MustParsePrefix("100.64.1.2/32")}) + routesWithQuad100 := func(extra ...netip.Prefix) []netip.Prefix { + return append(extra, netip.MustParsePrefix("100.100.100.100/32")) + } + hostsFor := func(nm *netmap.NetworkMap) map[dnsname.FQDN][]netip.Addr { + var hosts map[dnsname.FQDN][]netip.Addr + appendNode := func(n tailcfg.NodeView) { + addrs := make([]netip.Addr, 0, n.Addresses().Len()) + for _, addr := range n.Addresses().All() { + addrs = append(addrs, addr.Addr()) + } + mak.Set(&hosts, must.Get(dnsname.ToFQDN(n.Name())), addrs) + } + if nm != nil && nm.SelfNode.Valid() { + appendNode(nm.SelfNode) + } + for _, n := range nm.Peers { + appendNode(n) + } + return hosts + } + + tests := []struct { + name string + steps func(*testing.T, *LocalBackend, func() *mockControl) + wantState ipn.State + wantCfg *wgcfg.Config + wantRouterCfg *router.Config + wantDNSCfg *dns.Config + }{ + { + name: "Initial", + // The configs are nil until the the LocalBackend is started. + wantState: ipn.NoState, + wantCfg: nil, + wantRouterCfg: nil, + wantDNSCfg: nil, + }, + { + name: "Start", + steps: func(t *testing.T, lb *LocalBackend, _ func() *mockControl) { + mustDo(t)(lb.Start(ipn.Options{})) + }, + // Once started, all configs must be reset and have their zero values. + wantState: ipn.NeedsLogin, + wantCfg: &wgcfg.Config{}, + wantRouterCfg: &router.Config{}, + wantDNSCfg: &dns.Config{}, + }, + { + name: "Start/Connect", + steps: func(t *testing.T, lb *LocalBackend, _ func() *mockControl) { + mustDo(t)(lb.Start(ipn.Options{})) + mustDo2(t)(lb.EditPrefs(connect)) + }, + // Same if WantRunning is true, but the auth is not completed yet. + wantState: ipn.NeedsLogin, + wantCfg: &wgcfg.Config{}, + wantRouterCfg: &router.Config{}, + wantDNSCfg: &dns.Config{}, + }, + { + name: "Start/Connect/Login", + steps: func(t *testing.T, lb *LocalBackend, cc func() *mockControl) { + mustDo(t)(lb.Start(ipn.Options{})) + mustDo2(t)(lb.EditPrefs(connect)) + cc().authenticated(node1) + }, + // After the auth is completed, the configs must be updated to reflect the node's netmap. + wantState: ipn.Starting, + wantCfg: &wgcfg.Config{ + Name: "tailscale", + NodeID: node1.SelfNode.StableID(), + Peers: []wgcfg.Peer{}, + Addresses: node1.SelfNode.Addresses().AsSlice(), + }, + wantRouterCfg: &router.Config{ + SNATSubnetRoutes: true, + NetfilterMode: preftype.NetfilterOn, + LocalAddrs: node1.SelfNode.Addresses().AsSlice(), + Routes: routesWithQuad100(), + }, + wantDNSCfg: &dns.Config{ + Routes: map[dnsname.FQDN][]*dnstype.Resolver{}, + Hosts: hostsFor(node1), + }, + }, + { + name: "Start/Connect/Login/Disconnect", + steps: func(t *testing.T, lb *LocalBackend, cc func() *mockControl) { + mustDo(t)(lb.Start(ipn.Options{})) + mustDo2(t)(lb.EditPrefs(connect)) + cc().authenticated(node1) + mustDo2(t)(lb.EditPrefs(disconnect)) + }, + // After disconnecting, all configs must be reset and have their zero values. + wantState: ipn.Stopped, + wantCfg: &wgcfg.Config{}, + wantRouterCfg: &router.Config{}, + wantDNSCfg: &dns.Config{}, + }, + { + name: "Start/Connect/Login/NewProfile", + steps: func(t *testing.T, lb *LocalBackend, cc func() *mockControl) { + mustDo(t)(lb.Start(ipn.Options{})) + mustDo2(t)(lb.EditPrefs(connect)) + cc().authenticated(node1) + mustDo(t)(lb.NewProfile()) + }, + // After switching to a new, empty profile, all configs should be reset + // and have their zero values until the auth is completed. + wantState: ipn.NeedsLogin, + wantCfg: &wgcfg.Config{}, + wantRouterCfg: &router.Config{}, + wantDNSCfg: &dns.Config{}, + }, + { + name: "Start/Connect/Login/NewProfile/Login", + steps: func(t *testing.T, lb *LocalBackend, cc func() *mockControl) { + mustDo(t)(lb.Start(ipn.Options{})) + mustDo2(t)(lb.EditPrefs(connect)) + cc().authenticated(node1) + mustDo(t)(lb.NewProfile()) + mustDo2(t)(lb.EditPrefs(connect)) + cc().authenticated(node2) + }, + // Once the auth is completed, the configs must be updated to reflect the node's netmap. + wantState: ipn.Starting, + wantCfg: &wgcfg.Config{ + Name: "tailscale", + NodeID: node2.SelfNode.StableID(), + Peers: []wgcfg.Peer{}, + Addresses: node2.SelfNode.Addresses().AsSlice(), + }, + wantRouterCfg: &router.Config{ + SNATSubnetRoutes: true, + NetfilterMode: preftype.NetfilterOn, + LocalAddrs: node2.SelfNode.Addresses().AsSlice(), + Routes: routesWithQuad100(), + }, + wantDNSCfg: &dns.Config{ + Routes: map[dnsname.FQDN][]*dnstype.Resolver{}, + Hosts: hostsFor(node2), + }, + }, + { + name: "Start/Connect/Login/SwitchProfile", + steps: func(t *testing.T, lb *LocalBackend, cc func() *mockControl) { + mustDo(t)(lb.Start(ipn.Options{})) + mustDo2(t)(lb.EditPrefs(connect)) + cc().authenticated(node1) + profileID := lb.CurrentProfile().ID() + mustDo(t)(lb.NewProfile()) + cc().authenticated(node2) + mustDo(t)(lb.SwitchProfile(profileID)) + }, + // After switching to an existing profile, all configs must be reset + // and have their zero values until the (non-interactive) login is completed. + wantState: ipn.NoState, + wantCfg: &wgcfg.Config{}, + wantRouterCfg: &router.Config{}, + wantDNSCfg: &dns.Config{}, + }, + { + name: "Start/Connect/Login/SwitchProfile/NonInteractiveLogin", + steps: func(t *testing.T, lb *LocalBackend, cc func() *mockControl) { + mustDo(t)(lb.Start(ipn.Options{})) + mustDo2(t)(lb.EditPrefs(connect)) + cc().authenticated(node1) + profileID := lb.CurrentProfile().ID() + mustDo(t)(lb.NewProfile()) + cc().authenticated(node2) + mustDo(t)(lb.SwitchProfile(profileID)) + cc().authenticated(node1) // complete the login + }, + // After switching profiles and completing the auth, the configs + // must be updated to reflect the node's netmap. + wantState: ipn.Starting, + wantCfg: &wgcfg.Config{ + Name: "tailscale", + NodeID: node1.SelfNode.StableID(), + Peers: []wgcfg.Peer{}, + Addresses: node1.SelfNode.Addresses().AsSlice(), + }, + wantRouterCfg: &router.Config{ + SNATSubnetRoutes: true, + NetfilterMode: preftype.NetfilterOn, + LocalAddrs: node1.SelfNode.Addresses().AsSlice(), + Routes: routesWithQuad100(), + }, + wantDNSCfg: &dns.Config{ + Routes: map[dnsname.FQDN][]*dnstype.Resolver{}, + Hosts: hostsFor(node1), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + lb, engine, cc := newLocalBackendWithMockEngineAndControl(t, enableLogging) + + if tt.steps != nil { + tt.steps(t, lb, cc) + } + + if gotState := lb.State(); gotState != tt.wantState { + t.Errorf("State: got %v; want %v", gotState, tt.wantState) + } + + opts := []cmp.Option{ + cmpopts.EquateComparable(key.NodePublic{}, netip.Addr{}, netip.Prefix{}), + } + if diff := cmp.Diff(tt.wantCfg, engine.Config(), opts...); diff != "" { + t.Errorf("wgcfg.Config(+got -want): %v", diff) + } + if diff := cmp.Diff(tt.wantRouterCfg, engine.RouterConfig(), opts...); diff != "" { + t.Errorf("router.Config(+got -want): %v", diff) + } + if diff := cmp.Diff(tt.wantDNSCfg, engine.DNSConfig(), opts...); diff != "" { + t.Errorf("dns.Config(+got -want): %v", diff) + } + }) + } +} + +func testNetmapForNode(userID tailcfg.UserID, name string, addresses []netip.Prefix) *netmap.NetworkMap { + const ( + domain = "example.com" + magicDNSSuffix = ".test.ts.net" + ) + user := &tailcfg.UserProfile{ + ID: userID, + DisplayName: name, + LoginName: strings.Join([]string{name, domain}, "@"), + } + self := &tailcfg.Node{ + ID: tailcfg.NodeID(1000 + userID), + StableID: tailcfg.StableNodeID("stable-" + name), + User: user.ID, + Name: name + magicDNSSuffix, + Addresses: addresses, + MachineAuthorized: true, + } + return &netmap.NetworkMap{ + SelfNode: self.View(), + Name: self.Name, + Domain: domain, + UserProfiles: map[tailcfg.UserID]tailcfg.UserProfileView{ + user.ID: user.View(), + }, + } +} + +func mustDo(t *testing.T) func(error) { + t.Helper() + return func(err error) { + t.Helper() + if err != nil { + t.Fatal(err) + } + } +} + +func mustDo2(t *testing.T) func(any, error) { + t.Helper() + return func(_ any, err error) { + t.Helper() + if err != nil { + t.Fatal(err) + } + } +} + +func newLocalBackendWithMockEngineAndControl(t *testing.T, enableLogging bool) (*LocalBackend, *mockEngine, func() *mockControl) { + t.Helper() + + logf := logger.Discard + if enableLogging { + logf = tstest.WhileTestRunningLogger(t) + } + + dialer := &tsdial.Dialer{Logf: logf} + dialer.SetNetMon(netmon.NewStatic()) + + sys := &tsd.System{} + sys.Set(dialer) + sys.Set(dialer.NetMon()) + + magicConn, err := magicsock.NewConn(magicsock.Options{ + Logf: logf, + NetMon: dialer.NetMon(), + Metrics: sys.UserMetricsRegistry(), + HealthTracker: sys.HealthTracker(), + DisablePortMapper: true, + }) + if err != nil { + t.Fatalf("NewConn failed: %v", err) + } + magicConn.SetNetworkUp(dialer.NetMon().InterfaceState().AnyInterfaceUp()) + sys.Set(magicConn) + + engine := newMockEngine() + sys.Set(engine) + t.Cleanup(func() { + engine.Close() + <-engine.Done() + }) + + lb := newLocalBackendWithSysAndTestControl(t, enableLogging, sys, func(tb testing.TB, opts controlclient.Options) controlclient.Client { + return newClient(tb, opts) + }) + return lb, engine, func() *mockControl { return lb.cc.(*mockControl) } +} + +var _ wgengine.Engine = (*mockEngine)(nil) + +// mockEngine implements [wgengine.Engine]. +type mockEngine struct { + done chan struct{} // closed when Close is called + + mu sync.Mutex // protects all following fields + closed bool + cfg *wgcfg.Config + routerCfg *router.Config + dnsCfg *dns.Config + + filter, jailedFilter *filter.Filter + + statusCb wgengine.StatusCallback +} + +func newMockEngine() *mockEngine { + return &mockEngine{ + done: make(chan struct{}), + } +} + +func (e *mockEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, dnsCfg *dns.Config) error { + e.mu.Lock() + defer e.mu.Unlock() + if e.closed { + return errors.New("engine closed") + } + e.cfg = cfg + e.routerCfg = routerCfg + e.dnsCfg = dnsCfg + return nil +} + +func (e *mockEngine) Config() *wgcfg.Config { + e.mu.Lock() + defer e.mu.Unlock() + return e.cfg +} + +func (e *mockEngine) RouterConfig() *router.Config { + e.mu.Lock() + defer e.mu.Unlock() + return e.routerCfg +} + +func (e *mockEngine) DNSConfig() *dns.Config { + e.mu.Lock() + defer e.mu.Unlock() + return e.dnsCfg +} + +func (e *mockEngine) PeerForIP(netip.Addr) (_ wgengine.PeerForIP, ok bool) { + return wgengine.PeerForIP{}, false +} + +func (e *mockEngine) GetFilter() *filter.Filter { + e.mu.Lock() + defer e.mu.Unlock() + return e.filter +} + +func (e *mockEngine) SetFilter(f *filter.Filter) { + e.mu.Lock() + e.filter = f + e.mu.Unlock() +} + +func (e *mockEngine) GetJailedFilter() *filter.Filter { + e.mu.Lock() + defer e.mu.Unlock() + return e.jailedFilter +} + +func (e *mockEngine) SetJailedFilter(f *filter.Filter) { + e.mu.Lock() + e.jailedFilter = f + e.mu.Unlock() +} + +func (e *mockEngine) SetStatusCallback(cb wgengine.StatusCallback) { + e.mu.Lock() + e.statusCb = cb + e.mu.Unlock() +} + +func (e *mockEngine) RequestStatus() { + e.mu.Lock() + cb := e.statusCb + e.mu.Unlock() + if cb != nil { + cb(&wgengine.Status{AsOf: time.Now()}, nil) + } +} + +func (e *mockEngine) PeerByKey(key.NodePublic) (_ wgint.Peer, ok bool) { + return wgint.Peer{}, false +} + +func (e *mockEngine) SetNetworkMap(*netmap.NetworkMap) {} + +func (e *mockEngine) UpdateStatus(*ipnstate.StatusBuilder) {} + +func (e *mockEngine) Ping(ip netip.Addr, pingType tailcfg.PingType, size int, cb func(*ipnstate.PingResult)) { + cb(&ipnstate.PingResult{IP: ip.String(), Err: "not implemented"}) +} + +func (e *mockEngine) InstallCaptureHook(packet.CaptureCallback) {} + +func (e *mockEngine) Close() { + e.mu.Lock() + defer e.mu.Unlock() + if e.closed { + return + } + e.closed = true + close(e.done) +} + +func (e *mockEngine) Done() <-chan struct{} { + return e.done +}