diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index 70efaddae..3d79618c0 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -80,6 +80,7 @@ import ( "tailscale.com/types/dnstype" "tailscale.com/types/empty" "tailscale.com/types/key" + "tailscale.com/types/lazy" "tailscale.com/types/logger" "tailscale.com/types/logid" "tailscale.com/types/netmap" @@ -6445,7 +6446,7 @@ func (b *LocalBackend) SuggestExitNode() (response apitype.ExitNodeSuggestionRes return last, err } - res, err := suggestExitNode(lastReport, netMap, randomRegion, randomNode) + res, err := suggestExitNode(lastReport, netMap, randomRegion, randomNode, getAllowedSuggestions()) if err != nil { last, err := lastSuggestedExitNode.asAPIType() if err != nil { @@ -6479,22 +6480,34 @@ func (n lastSuggestedExitNode) asAPIType() (res apitype.ExitNodeSuggestionRespon return res, ErrUnableToSuggestLastExitNode } -func suggestExitNode(report *netcheck.Report, netMap *netmap.NetworkMap, selectRegion selectRegionFunc, selectNode selectNodeFunc) (res apitype.ExitNodeSuggestionResponse, err error) { +var getAllowedSuggestions = lazy.SyncFunc(fillAllowedSuggestions) + +func fillAllowedSuggestions() set.Set[tailcfg.StableNodeID] { + nodes, err := syspolicy.GetStringArray(syspolicy.AllowedSuggestedExitNodes, nil) + if err != nil { + log.Printf("fillAllowedSuggestions: unable to look up %q policy: %v", syspolicy.AllowedSuggestedExitNodes, err) + return nil + } + if nodes == nil { + return nil + } + s := make(set.Set[tailcfg.StableNodeID], len(nodes)) + for _, n := range nodes { + s.Add(tailcfg.StableNodeID(n)) + } + return s +} + +func suggestExitNode(report *netcheck.Report, netMap *netmap.NetworkMap, selectRegion selectRegionFunc, selectNode selectNodeFunc, allowList set.Set[tailcfg.StableNodeID]) (res apitype.ExitNodeSuggestionResponse, err error) { if report.PreferredDERP == 0 || netMap == nil || netMap.DERPMap == nil { return res, ErrNoPreferredDERP } - var allowedCandidates set.Set[string] - if allowed, err := syspolicy.GetStringArray(syspolicy.AllowedSuggestedExitNodes, nil); err != nil { - return res, fmt.Errorf("unable to read %s policy: %w", syspolicy.AllowedSuggestedExitNodes, err) - } else if allowed != nil { - allowedCandidates = set.SetOf(allowed) - } candidates := make([]tailcfg.NodeView, 0, len(netMap.Peers)) for _, peer := range netMap.Peers { if !peer.Valid() { continue } - if allowedCandidates != nil && !allowedCandidates.Contains(string(peer.StableID())) { + if allowList != nil && !allowList.Contains(peer.StableID()) { continue } if peer.CapMap().Has(tailcfg.NodeAttrSuggestExitNode) && tsaddr.ContainsExitRoutes(peer.AllowedIPs()) { diff --git a/ipn/ipnlocal/local_test.go b/ipn/ipnlocal/local_test.go index 70da1130e..cb3b5c00a 100644 --- a/ipn/ipnlocal/local_test.go +++ b/ipn/ipnlocal/local_test.go @@ -40,6 +40,7 @@ import ( "tailscale.com/tstest" "tailscale.com/types/dnstype" "tailscale.com/types/key" + "tailscale.com/types/lazy" "tailscale.com/types/logger" "tailscale.com/types/logid" "tailscale.com/types/netmap" @@ -49,6 +50,7 @@ import ( "tailscale.com/util/dnsname" "tailscale.com/util/mak" "tailscale.com/util/must" + "tailscale.com/util/set" "tailscale.com/util/syspolicy" "tailscale.com/wgengine" "tailscale.com/wgengine/filter" @@ -2823,6 +2825,8 @@ func deterministicNodeForTest(t testing.TB, want views.Slice[tailcfg.StableNodeI } func TestSuggestExitNode(t *testing.T) { + t.Parallel() + defaultDERPMap := &tailcfg.DERPMap{ Regions: map[int]*tailcfg.DERPRegion{ 1: { @@ -2963,7 +2967,7 @@ func TestSuggestExitNode(t *testing.T) { netMap *netmap.NetworkMap lastSuggestion tailcfg.StableNodeID - allowPolicy []string + allowPolicy []tailcfg.StableNodeID wantRegions []int useRegion int @@ -3187,13 +3191,13 @@ func TestSuggestExitNode(t *testing.T) { name: "no allowed suggestions", lastReport: preferred1Report, netMap: largeNetmap, - allowPolicy: []string{}, + allowPolicy: []tailcfg.StableNodeID{}, }, { name: "only derp suggestions", lastReport: preferred1Report, netMap: largeNetmap, - allowPolicy: []string{"stable1", "stable2", "stable3"}, + allowPolicy: []tailcfg.StableNodeID{"stable1", "stable2", "stable3"}, wantNodes: []tailcfg.StableNodeID{"stable1", "stable2"}, wantID: "stable2", wantName: "peer2", @@ -3202,7 +3206,7 @@ func TestSuggestExitNode(t *testing.T) { name: "only mullvad suggestions", lastReport: preferred1Report, netMap: largeNetmap, - allowPolicy: []string{"stable5", "stable6", "stable7"}, + allowPolicy: []tailcfg.StableNodeID{"stable5", "stable6", "stable7"}, wantID: "stable7", wantName: "Fort Worth", wantLocation: fortWorth.View(), @@ -3211,7 +3215,7 @@ func TestSuggestExitNode(t *testing.T) { name: "only worst derp", lastReport: preferred1Report, netMap: largeNetmap, - allowPolicy: []string{"stable3"}, + allowPolicy: []tailcfg.StableNodeID{"stable3"}, wantID: "stable3", wantName: "peer3", }, @@ -3219,7 +3223,7 @@ func TestSuggestExitNode(t *testing.T) { name: "only worst mullvad", lastReport: preferred1Report, netMap: largeNetmap, - allowPolicy: []string{"stable6"}, + allowPolicy: []tailcfg.StableNodeID{"stable6"}, wantID: "stable6", wantName: "San Jose", wantLocation: sanJose.View(), @@ -3228,16 +3232,6 @@ func TestSuggestExitNode(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - mh := mockSyspolicyHandler{ - t: t, - } - if tt.allowPolicy != nil { - mh.stringArrayPolicies = map[syspolicy.Key][]string{ - syspolicy.AllowedSuggestedExitNodes: tt.allowPolicy, - } - } - syspolicy.SetHandlerForTest(t, &mh) - wantRegions := tt.wantRegions if wantRegions == nil { wantRegions = []int{tt.useRegion} @@ -3250,7 +3244,12 @@ func TestSuggestExitNode(t *testing.T) { } selectNode := deterministicNodeForTest(t, views.SliceOf(wantNodes), tt.wantID) - got, err := suggestExitNode(tt.lastReport, tt.netMap, selectRegion, selectNode) + var allowList set.Set[tailcfg.StableNodeID] + if tt.allowPolicy != nil { + allowList = set.SetOf(tt.allowPolicy) + } + + got, err := suggestExitNode(tt.lastReport, tt.netMap, selectRegion, selectNode, allowList) if got.Name != tt.wantName { t.Errorf("name=%v, want %v", got.Name, tt.wantName) } @@ -3877,33 +3876,36 @@ func TestLocalBackendSuggestExitNode(t *testing.T) { } for _, tt := range tests { - lb := newTestLocalBackend(t) - msh := &mockSyspolicyHandler{ - t: t, - stringArrayPolicies: map[syspolicy.Key][]string{ - syspolicy.AllowedSuggestedExitNodes: nil, - }, - } - if len(tt.allowedSuggestedExitNodes) != 0 { - msh.stringArrayPolicies[syspolicy.AllowedSuggestedExitNodes] = tt.allowedSuggestedExitNodes - } - syspolicy.SetHandlerForTest(t, msh) - lb.lastSuggestedExitNode = tt.lastSuggestedExitNode - lb.netMap = &tt.netMap - lb.sys.MagicSock.Get().SetLastNetcheckReportForTest(context.Background(), tt.report) - got, err := lb.SuggestExitNode() - if got.ID != tt.wantID { - t.Errorf("ID=%v, want=%v", got.ID, tt.wantID) - } - if got.Name != tt.wantName { - t.Errorf("Name=%v, want=%v", got.Name, tt.wantName) - } - if lb.lastSuggestedExitNode != tt.wantLastSuggestedExitNode { - t.Errorf("lastSuggestedExitNode=%v, want=%v", lb.lastSuggestedExitNode, tt.wantLastSuggestedExitNode) - } - if err != tt.wantErr { - t.Errorf("Error=%v, want=%v", err, tt.wantErr) - } + t.Run(tt.name, func(t *testing.T) { + lb := newTestLocalBackend(t) + msh := &mockSyspolicyHandler{ + t: t, + stringArrayPolicies: map[syspolicy.Key][]string{ + syspolicy.AllowedSuggestedExitNodes: nil, + }, + } + if len(tt.allowedSuggestedExitNodes) != 0 { + msh.stringArrayPolicies[syspolicy.AllowedSuggestedExitNodes] = tt.allowedSuggestedExitNodes + } + syspolicy.SetHandlerForTest(t, msh) + getAllowedSuggestions = lazy.SyncFunc(fillAllowedSuggestions) // clear cache + lb.lastSuggestedExitNode = tt.lastSuggestedExitNode + lb.netMap = &tt.netMap + lb.sys.MagicSock.Get().SetLastNetcheckReportForTest(context.Background(), tt.report) + got, err := lb.SuggestExitNode() + if got.ID != tt.wantID { + t.Errorf("ID=%v, want=%v", got.ID, tt.wantID) + } + if got.Name != tt.wantName { + t.Errorf("Name=%v, want=%v", got.Name, tt.wantName) + } + if lb.lastSuggestedExitNode != tt.wantLastSuggestedExitNode { + t.Errorf("lastSuggestedExitNode=%v, want=%v", lb.lastSuggestedExitNode, tt.wantLastSuggestedExitNode) + } + if err != tt.wantErr { + t.Errorf("Error=%v, want=%v", err, tt.wantErr) + } + }) } } func TestEnableAutoUpdates(t *testing.T) { @@ -4004,3 +4006,63 @@ func TestReadWriteRouteInfo(t *testing.T) { t.Fatalf("read prof2 routeInfo wildcards: want %v, got %v", ri2.Wildcards, readRi.Wildcards) } } + +func TestFillAllowedSuggestions(t *testing.T) { + tests := []struct { + name string + allowPolicy []string + want []tailcfg.StableNodeID + }{ + { + name: "unset", + }, + { + name: "zero", + allowPolicy: []string{}, + want: []tailcfg.StableNodeID{}, + }, + { + name: "one", + allowPolicy: []string{"one"}, + want: []tailcfg.StableNodeID{"one"}, + }, + { + name: "many", + allowPolicy: []string{"one", "two", "three", "four"}, + want: []tailcfg.StableNodeID{"one", "three", "four", "two"}, // order should not matter + }, + { + name: "preserve case", + allowPolicy: []string{"ABC", "def", "gHiJ"}, + want: []tailcfg.StableNodeID{"ABC", "def", "gHiJ"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mh := mockSyspolicyHandler{ + t: t, + } + if tt.allowPolicy != nil { + mh.stringArrayPolicies = map[syspolicy.Key][]string{ + syspolicy.AllowedSuggestedExitNodes: tt.allowPolicy, + } + } + syspolicy.SetHandlerForTest(t, &mh) + + got := fillAllowedSuggestions() + if got == nil { + if tt.want == nil { + return + } + t.Errorf("got nil, want %v", tt.want) + } + if tt.want == nil { + t.Errorf("got %v, want nil", got) + } + + if !got.Equal(set.SetOf(tt.want)) { + t.Errorf("got %v, want %v", got, tt.want) + } + }) + } +}