diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index 1f9f7e8b2..8223e86d0 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -5706,13 +5706,15 @@ func (b *LocalBackend) enterStateLockedOnEntry(newState ipn.State, unlock unlock } b.blockEngineUpdates(true) fallthrough - case ipn.Stopped: + case ipn.Stopped, ipn.NoState: + // Unconfigure the engine if it has stopped (WantRunning is set to false) + // or if we've switched to a different profile and the state is unknown. err := b.e.Reconfig(&wgcfg.Config{}, &router.Config{}, &dns.Config{}) if err != nil { b.logf("Reconfig(down): %v", err) } - if authURL == "" { + if newState == ipn.Stopped && authURL == "" { systemd.Status("Stopped; run 'tailscale up' to log in") } case ipn.Starting, ipn.NeedsMachineAuth: @@ -5726,8 +5728,6 @@ func (b *LocalBackend) enterStateLockedOnEntry(newState ipn.State, unlock unlock addrStrs = append(addrStrs, p.Addr().String()) } systemd.Status("Connected; %s; %s", activeLogin, strings.Join(addrStrs, " ")) - case ipn.NoState: - // Do nothing. default: b.logf("[unexpected] unknown newState %#v", newState) } diff --git a/ipn/ipnlocal/local_test.go b/ipn/ipnlocal/local_test.go index 5b74b8180..2579590a8 100644 --- a/ipn/ipnlocal/local_test.go +++ b/ipn/ipnlocal/local_test.go @@ -4407,19 +4407,27 @@ func TestNotificationTargetMatch(t *testing.T) { type newTestControlFn func(tb testing.TB, opts controlclient.Options) controlclient.Client func newLocalBackendWithTestControl(t *testing.T, enableLogging bool, newControl newTestControlFn) *LocalBackend { + return newLocalBackendWithSysAndTestControl(t, enableLogging, new(tsd.System), newControl) +} + +func newLocalBackendWithSysAndTestControl(t *testing.T, enableLogging bool, sys *tsd.System, newControl newTestControlFn) *LocalBackend { logf := logger.Discard if enableLogging { logf = tstest.WhileTestRunningLogger(t) } - sys := new(tsd.System) - store := new(mem.Store) - sys.Set(store) - e, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.HealthTracker(), sys.UserMetricsRegistry()) - if err != nil { - t.Fatalf("NewFakeUserspaceEngine: %v", err) + + if _, hasStore := sys.StateStore.GetOK(); !hasStore { + store := new(mem.Store) + sys.Set(store) + } + if _, hasEngine := sys.Engine.GetOK(); !hasEngine { + e, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.HealthTracker(), sys.UserMetricsRegistry()) + if err != nil { + t.Fatalf("NewFakeUserspaceEngine: %v", err) + } + t.Cleanup(e.Close) + sys.Set(e) } - t.Cleanup(e.Close) - sys.Set(e) b, err := NewLocalBackend(logf, logid.PublicID{}, sys, 0) if err != nil { diff --git a/ipn/ipnlocal/state_test.go b/ipn/ipnlocal/state_test.go index a4180de86..3c22b66be 100644 --- a/ipn/ipnlocal/state_test.go +++ b/ipn/ipnlocal/state_test.go @@ -5,26 +5,46 @@ package ipnlocal import ( "context" + "errors" + "net/netip" + "strings" "sync" "sync/atomic" "testing" "time" qt "github.com/frankban/quicktest" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "tailscale.com/control/controlclient" "tailscale.com/envknob" "tailscale.com/ipn" + "tailscale.com/ipn/ipnstate" "tailscale.com/ipn/store/mem" + "tailscale.com/net/dns" + "tailscale.com/net/netmon" + "tailscale.com/net/packet" + "tailscale.com/net/tsdial" "tailscale.com/tailcfg" "tailscale.com/tsd" "tailscale.com/tstest" + "tailscale.com/types/dnstype" "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/logid" "tailscale.com/types/netmap" "tailscale.com/types/persist" + "tailscale.com/types/preftype" + "tailscale.com/util/dnsname" + "tailscale.com/util/mak" + "tailscale.com/util/must" "tailscale.com/wgengine" + "tailscale.com/wgengine/filter" + "tailscale.com/wgengine/magicsock" + "tailscale.com/wgengine/router" + "tailscale.com/wgengine/wgcfg" + "tailscale.com/wgengine/wgint" ) // notifyThrottler receives notifications from an ipn.Backend, blocking @@ -170,6 +190,14 @@ func (cc *mockControl) send(err error, url string, loginFinished bool, nm *netma } } +func (cc *mockControl) authenticated(nm *netmap.NetworkMap) { + if selfUser, ok := nm.UserProfiles[nm.SelfNode.User()]; ok { + cc.persist.UserProfile = *selfUser.AsStruct() + } + cc.persist.NodeID = nm.SelfNode.StableID() + cc.send(nil, "", true, nm) +} + // called records that a particular function name was called. func (cc *mockControl) called(s string) { cc.mu.Lock() @@ -1076,3 +1104,449 @@ func TestWGEngineStatusRace(t *testing.T) { wg.Wait() wantState(ipn.Running) } + +// TestEngineReconfigOnStateChange verifies that wgengine is properly reconfigured +// when the LocalBackend's state changes, such as when the user logs in, switches +// profiles, or disconnects from Tailscale. +func TestEngineReconfigOnStateChange(t *testing.T) { + enableLogging := false + connect := &ipn.MaskedPrefs{Prefs: ipn.Prefs{WantRunning: true}, WantRunningSet: true} + disconnect := &ipn.MaskedPrefs{Prefs: ipn.Prefs{WantRunning: false}, WantRunningSet: true} + node1 := testNetmapForNode(1, "node-1", []netip.Prefix{netip.MustParsePrefix("100.64.1.1/32")}) + node2 := testNetmapForNode(2, "node-2", []netip.Prefix{netip.MustParsePrefix("100.64.1.2/32")}) + routesWithQuad100 := func(extra ...netip.Prefix) []netip.Prefix { + return append(extra, netip.MustParsePrefix("100.100.100.100/32")) + } + hostsFor := func(nm *netmap.NetworkMap) map[dnsname.FQDN][]netip.Addr { + var hosts map[dnsname.FQDN][]netip.Addr + appendNode := func(n tailcfg.NodeView) { + addrs := make([]netip.Addr, 0, n.Addresses().Len()) + for _, addr := range n.Addresses().All() { + addrs = append(addrs, addr.Addr()) + } + mak.Set(&hosts, must.Get(dnsname.ToFQDN(n.Name())), addrs) + } + if nm != nil && nm.SelfNode.Valid() { + appendNode(nm.SelfNode) + } + for _, n := range nm.Peers { + appendNode(n) + } + return hosts + } + + tests := []struct { + name string + steps func(*testing.T, *LocalBackend, func() *mockControl) + wantState ipn.State + wantCfg *wgcfg.Config + wantRouterCfg *router.Config + wantDNSCfg *dns.Config + }{ + { + name: "Initial", + // The configs are nil until the the LocalBackend is started. + wantState: ipn.NoState, + wantCfg: nil, + wantRouterCfg: nil, + wantDNSCfg: nil, + }, + { + name: "Start", + steps: func(t *testing.T, lb *LocalBackend, _ func() *mockControl) { + mustDo(t)(lb.Start(ipn.Options{})) + }, + // Once started, all configs must be reset and have their zero values. + wantState: ipn.NeedsLogin, + wantCfg: &wgcfg.Config{}, + wantRouterCfg: &router.Config{}, + wantDNSCfg: &dns.Config{}, + }, + { + name: "Start/Connect", + steps: func(t *testing.T, lb *LocalBackend, _ func() *mockControl) { + mustDo(t)(lb.Start(ipn.Options{})) + mustDo2(t)(lb.EditPrefs(connect)) + }, + // Same if WantRunning is true, but the auth is not completed yet. + wantState: ipn.NeedsLogin, + wantCfg: &wgcfg.Config{}, + wantRouterCfg: &router.Config{}, + wantDNSCfg: &dns.Config{}, + }, + { + name: "Start/Connect/Login", + steps: func(t *testing.T, lb *LocalBackend, cc func() *mockControl) { + mustDo(t)(lb.Start(ipn.Options{})) + mustDo2(t)(lb.EditPrefs(connect)) + cc().authenticated(node1) + }, + // After the auth is completed, the configs must be updated to reflect the node's netmap. + wantState: ipn.Starting, + wantCfg: &wgcfg.Config{ + Name: "tailscale", + NodeID: node1.SelfNode.StableID(), + Peers: []wgcfg.Peer{}, + Addresses: node1.SelfNode.Addresses().AsSlice(), + }, + wantRouterCfg: &router.Config{ + SNATSubnetRoutes: true, + NetfilterMode: preftype.NetfilterOn, + LocalAddrs: node1.SelfNode.Addresses().AsSlice(), + Routes: routesWithQuad100(), + }, + wantDNSCfg: &dns.Config{ + Routes: map[dnsname.FQDN][]*dnstype.Resolver{}, + Hosts: hostsFor(node1), + }, + }, + { + name: "Start/Connect/Login/Disconnect", + steps: func(t *testing.T, lb *LocalBackend, cc func() *mockControl) { + mustDo(t)(lb.Start(ipn.Options{})) + mustDo2(t)(lb.EditPrefs(connect)) + cc().authenticated(node1) + mustDo2(t)(lb.EditPrefs(disconnect)) + }, + // After disconnecting, all configs must be reset and have their zero values. + wantState: ipn.Stopped, + wantCfg: &wgcfg.Config{}, + wantRouterCfg: &router.Config{}, + wantDNSCfg: &dns.Config{}, + }, + { + name: "Start/Connect/Login/NewProfile", + steps: func(t *testing.T, lb *LocalBackend, cc func() *mockControl) { + mustDo(t)(lb.Start(ipn.Options{})) + mustDo2(t)(lb.EditPrefs(connect)) + cc().authenticated(node1) + mustDo(t)(lb.NewProfile()) + }, + // After switching to a new, empty profile, all configs should be reset + // and have their zero values until the auth is completed. + wantState: ipn.NeedsLogin, + wantCfg: &wgcfg.Config{}, + wantRouterCfg: &router.Config{}, + wantDNSCfg: &dns.Config{}, + }, + { + name: "Start/Connect/Login/NewProfile/Login", + steps: func(t *testing.T, lb *LocalBackend, cc func() *mockControl) { + mustDo(t)(lb.Start(ipn.Options{})) + mustDo2(t)(lb.EditPrefs(connect)) + cc().authenticated(node1) + mustDo(t)(lb.NewProfile()) + mustDo2(t)(lb.EditPrefs(connect)) + cc().authenticated(node2) + }, + // Once the auth is completed, the configs must be updated to reflect the node's netmap. + wantState: ipn.Starting, + wantCfg: &wgcfg.Config{ + Name: "tailscale", + NodeID: node2.SelfNode.StableID(), + Peers: []wgcfg.Peer{}, + Addresses: node2.SelfNode.Addresses().AsSlice(), + }, + wantRouterCfg: &router.Config{ + SNATSubnetRoutes: true, + NetfilterMode: preftype.NetfilterOn, + LocalAddrs: node2.SelfNode.Addresses().AsSlice(), + Routes: routesWithQuad100(), + }, + wantDNSCfg: &dns.Config{ + Routes: map[dnsname.FQDN][]*dnstype.Resolver{}, + Hosts: hostsFor(node2), + }, + }, + { + name: "Start/Connect/Login/SwitchProfile", + steps: func(t *testing.T, lb *LocalBackend, cc func() *mockControl) { + mustDo(t)(lb.Start(ipn.Options{})) + mustDo2(t)(lb.EditPrefs(connect)) + cc().authenticated(node1) + profileID := lb.CurrentProfile().ID() + mustDo(t)(lb.NewProfile()) + cc().authenticated(node2) + mustDo(t)(lb.SwitchProfile(profileID)) + }, + // After switching to an existing profile, all configs must be reset + // and have their zero values until the (non-interactive) login is completed. + wantState: ipn.NoState, + wantCfg: &wgcfg.Config{}, + wantRouterCfg: &router.Config{}, + wantDNSCfg: &dns.Config{}, + }, + { + name: "Start/Connect/Login/SwitchProfile/NonInteractiveLogin", + steps: func(t *testing.T, lb *LocalBackend, cc func() *mockControl) { + mustDo(t)(lb.Start(ipn.Options{})) + mustDo2(t)(lb.EditPrefs(connect)) + cc().authenticated(node1) + profileID := lb.CurrentProfile().ID() + mustDo(t)(lb.NewProfile()) + cc().authenticated(node2) + mustDo(t)(lb.SwitchProfile(profileID)) + cc().authenticated(node1) // complete the login + }, + // After switching profiles and completing the auth, the configs + // must be updated to reflect the node's netmap. + wantState: ipn.Starting, + wantCfg: &wgcfg.Config{ + Name: "tailscale", + NodeID: node1.SelfNode.StableID(), + Peers: []wgcfg.Peer{}, + Addresses: node1.SelfNode.Addresses().AsSlice(), + }, + wantRouterCfg: &router.Config{ + SNATSubnetRoutes: true, + NetfilterMode: preftype.NetfilterOn, + LocalAddrs: node1.SelfNode.Addresses().AsSlice(), + Routes: routesWithQuad100(), + }, + wantDNSCfg: &dns.Config{ + Routes: map[dnsname.FQDN][]*dnstype.Resolver{}, + Hosts: hostsFor(node1), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + lb, engine, cc := newLocalBackendWithMockEngineAndControl(t, enableLogging) + + if tt.steps != nil { + tt.steps(t, lb, cc) + } + + if gotState := lb.State(); gotState != tt.wantState { + t.Errorf("State: got %v; want %v", gotState, tt.wantState) + } + + opts := []cmp.Option{ + cmpopts.EquateComparable(key.NodePublic{}, netip.Addr{}, netip.Prefix{}), + } + if diff := cmp.Diff(tt.wantCfg, engine.Config(), opts...); diff != "" { + t.Errorf("wgcfg.Config(+got -want): %v", diff) + } + if diff := cmp.Diff(tt.wantRouterCfg, engine.RouterConfig(), opts...); diff != "" { + t.Errorf("router.Config(+got -want): %v", diff) + } + if diff := cmp.Diff(tt.wantDNSCfg, engine.DNSConfig(), opts...); diff != "" { + t.Errorf("dns.Config(+got -want): %v", diff) + } + }) + } +} + +func testNetmapForNode(userID tailcfg.UserID, name string, addresses []netip.Prefix) *netmap.NetworkMap { + const ( + domain = "example.com" + magicDNSSuffix = ".test.ts.net" + ) + user := &tailcfg.UserProfile{ + ID: userID, + DisplayName: name, + LoginName: strings.Join([]string{name, domain}, "@"), + } + self := &tailcfg.Node{ + ID: tailcfg.NodeID(1000 + userID), + StableID: tailcfg.StableNodeID("stable-" + name), + User: user.ID, + Name: name + magicDNSSuffix, + Addresses: addresses, + MachineAuthorized: true, + } + return &netmap.NetworkMap{ + SelfNode: self.View(), + Name: self.Name, + Domain: domain, + UserProfiles: map[tailcfg.UserID]tailcfg.UserProfileView{ + user.ID: user.View(), + }, + } +} + +func mustDo(t *testing.T) func(error) { + t.Helper() + return func(err error) { + t.Helper() + if err != nil { + t.Fatal(err) + } + } +} + +func mustDo2(t *testing.T) func(any, error) { + t.Helper() + return func(_ any, err error) { + t.Helper() + if err != nil { + t.Fatal(err) + } + } +} + +func newLocalBackendWithMockEngineAndControl(t *testing.T, enableLogging bool) (*LocalBackend, *mockEngine, func() *mockControl) { + t.Helper() + + logf := logger.Discard + if enableLogging { + logf = tstest.WhileTestRunningLogger(t) + } + + dialer := &tsdial.Dialer{Logf: logf} + dialer.SetNetMon(netmon.NewStatic()) + + sys := &tsd.System{} + sys.Set(dialer) + sys.Set(dialer.NetMon()) + + magicConn, err := magicsock.NewConn(magicsock.Options{ + Logf: logf, + NetMon: dialer.NetMon(), + Metrics: sys.UserMetricsRegistry(), + HealthTracker: sys.HealthTracker(), + DisablePortMapper: true, + }) + if err != nil { + t.Fatalf("NewConn failed: %v", err) + } + magicConn.SetNetworkUp(dialer.NetMon().InterfaceState().AnyInterfaceUp()) + sys.Set(magicConn) + + engine := newMockEngine() + sys.Set(engine) + t.Cleanup(func() { + engine.Close() + <-engine.Done() + }) + + lb := newLocalBackendWithSysAndTestControl(t, enableLogging, sys, func(tb testing.TB, opts controlclient.Options) controlclient.Client { + return newClient(tb, opts) + }) + return lb, engine, func() *mockControl { return lb.cc.(*mockControl) } +} + +var _ wgengine.Engine = (*mockEngine)(nil) + +// mockEngine implements [wgengine.Engine]. +type mockEngine struct { + done chan struct{} // closed when Close is called + + mu sync.Mutex // protects all following fields + closed bool + cfg *wgcfg.Config + routerCfg *router.Config + dnsCfg *dns.Config + + filter, jailedFilter *filter.Filter + + statusCb wgengine.StatusCallback +} + +func newMockEngine() *mockEngine { + return &mockEngine{ + done: make(chan struct{}), + } +} + +func (e *mockEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, dnsCfg *dns.Config) error { + e.mu.Lock() + defer e.mu.Unlock() + if e.closed { + return errors.New("engine closed") + } + e.cfg = cfg + e.routerCfg = routerCfg + e.dnsCfg = dnsCfg + return nil +} + +func (e *mockEngine) Config() *wgcfg.Config { + e.mu.Lock() + defer e.mu.Unlock() + return e.cfg +} + +func (e *mockEngine) RouterConfig() *router.Config { + e.mu.Lock() + defer e.mu.Unlock() + return e.routerCfg +} + +func (e *mockEngine) DNSConfig() *dns.Config { + e.mu.Lock() + defer e.mu.Unlock() + return e.dnsCfg +} + +func (e *mockEngine) PeerForIP(netip.Addr) (_ wgengine.PeerForIP, ok bool) { + return wgengine.PeerForIP{}, false +} + +func (e *mockEngine) GetFilter() *filter.Filter { + e.mu.Lock() + defer e.mu.Unlock() + return e.filter +} + +func (e *mockEngine) SetFilter(f *filter.Filter) { + e.mu.Lock() + e.filter = f + e.mu.Unlock() +} + +func (e *mockEngine) GetJailedFilter() *filter.Filter { + e.mu.Lock() + defer e.mu.Unlock() + return e.jailedFilter +} + +func (e *mockEngine) SetJailedFilter(f *filter.Filter) { + e.mu.Lock() + e.jailedFilter = f + e.mu.Unlock() +} + +func (e *mockEngine) SetStatusCallback(cb wgengine.StatusCallback) { + e.mu.Lock() + e.statusCb = cb + e.mu.Unlock() +} + +func (e *mockEngine) RequestStatus() { + e.mu.Lock() + cb := e.statusCb + e.mu.Unlock() + if cb != nil { + cb(&wgengine.Status{AsOf: time.Now()}, nil) + } +} + +func (e *mockEngine) PeerByKey(key.NodePublic) (_ wgint.Peer, ok bool) { + return wgint.Peer{}, false +} + +func (e *mockEngine) SetNetworkMap(*netmap.NetworkMap) {} + +func (e *mockEngine) UpdateStatus(*ipnstate.StatusBuilder) {} + +func (e *mockEngine) Ping(ip netip.Addr, pingType tailcfg.PingType, size int, cb func(*ipnstate.PingResult)) { + cb(&ipnstate.PingResult{IP: ip.String(), Err: "not implemented"}) +} + +func (e *mockEngine) InstallCaptureHook(packet.CaptureCallback) {} + +func (e *mockEngine) Close() { + e.mu.Lock() + defer e.mu.Unlock() + if e.closed { + return + } + e.closed = true + close(e.done) +} + +func (e *mockEngine) Done() <-chan struct{} { + return e.done +}