diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 04d4bbbde..a4ba090ef 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -274,11 +274,9 @@ type Conn struct { captureHook syncs.AtomicValue[packet.CaptureCallback] // hasPeerRelayServers is whether [relayManager] is configured with at least - // one peer relay server via [relayManager.handleRelayServersSet]. It is - // only accessed by [Conn.updateRelayServersSet], [endpoint.setDERPHome], - // and [endpoint.discoverUDPRelayPathsLocked]. It exists to suppress - // calls into [relayManager] leading to wasted work involving channel - // operations and goroutine creation. + // one peer relay server via [relayManager.handleRelayServersSet]. It exists + // to suppress calls into [relayManager] leading to wasted work involving + // channel operations and goroutine creation. hasPeerRelayServers atomic.Bool // discoPrivate is the private naclbox key used for active @@ -2998,6 +2996,7 @@ func (c *Conn) onNodeViewsUpdate(update NodeViewsUpdate) { if peersChanged || relayClientChanged { if !relayClientEnabled { c.relayManager.handleRelayServersSet(nil) + c.hasPeerRelayServers.Store(false) } else { c.updateRelayServersSet(filt, self, peers) } diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index e12f15b22..9399dab32 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -65,7 +65,6 @@ import ( "tailscale.com/types/netmap" "tailscale.com/types/nettype" "tailscale.com/types/ptr" - "tailscale.com/types/views" "tailscale.com/util/cibuild" "tailscale.com/util/clientmetric" "tailscale.com/util/eventbus" @@ -3584,7 +3583,7 @@ func Test_nodeHasCap(t *testing.T) { } } -func TestConn_updateRelayServersSet(t *testing.T) { +func TestConn_onNodeViewsUpdate_updateRelayServersSet(t *testing.T) { peerNodeCandidateRelay := &tailcfg.Node{ Cap: 121, ID: 1, @@ -3618,12 +3617,21 @@ func TestConn_updateRelayServersSet(t *testing.T) { DiscoKey: key.NewDisco().Public(), } + selfNodeNodeAttrDisableRelayClient := selfNode.Clone() + selfNodeNodeAttrDisableRelayClient.CapMap = make(tailcfg.NodeCapMap) + selfNodeNodeAttrDisableRelayClient.CapMap[tailcfg.NodeAttrDisableRelayClient] = nil + + selfNodeNodeAttrOnlyTCP443 := selfNode.Clone() + selfNodeNodeAttrOnlyTCP443.CapMap = make(tailcfg.NodeCapMap) + selfNodeNodeAttrOnlyTCP443.CapMap[tailcfg.NodeAttrOnlyTCP443] = nil + tests := []struct { - name string - filt *filter.Filter - self tailcfg.NodeView - peers views.Slice[tailcfg.NodeView] - wantRelayServers set.Set[candidatePeerRelay] + name string + filt *filter.Filter + self tailcfg.NodeView + peers []tailcfg.NodeView + wantRelayServers set.Set[candidatePeerRelay] + wantRelayClientEnabled bool }{ { name: "candidate relay server", @@ -3639,7 +3647,7 @@ func TestConn_updateRelayServersSet(t *testing.T) { }, }, nil, nil, nil, nil, nil), self: selfNode.View(), - peers: views.SliceOf([]tailcfg.NodeView{peerNodeCandidateRelay.View()}), + peers: []tailcfg.NodeView{peerNodeCandidateRelay.View()}, wantRelayServers: set.SetOf([]candidatePeerRelay{ { nodeKey: peerNodeCandidateRelay.Key, @@ -3647,6 +3655,43 @@ func TestConn_updateRelayServersSet(t *testing.T) { derpHomeRegionID: 1, }, }), + wantRelayClientEnabled: true, + }, + { + name: "no candidate relay server because self has tailcfg.NodeAttrDisableRelayClient", + filt: filter.New([]filtertype.Match{ + { + Srcs: peerNodeCandidateRelay.Addresses, + Caps: []filtertype.CapMatch{ + { + Dst: selfNodeNodeAttrDisableRelayClient.Addresses[0], + Cap: tailcfg.PeerCapabilityRelayTarget, + }, + }, + }, + }, nil, nil, nil, nil, nil), + self: selfNodeNodeAttrDisableRelayClient.View(), + peers: []tailcfg.NodeView{peerNodeCandidateRelay.View()}, + wantRelayServers: make(set.Set[candidatePeerRelay]), + wantRelayClientEnabled: false, + }, + { + name: "no candidate relay server because self has tailcfg.NodeAttrOnlyTCP443", + filt: filter.New([]filtertype.Match{ + { + Srcs: peerNodeCandidateRelay.Addresses, + Caps: []filtertype.CapMatch{ + { + Dst: selfNodeNodeAttrOnlyTCP443.Addresses[0], + Cap: tailcfg.PeerCapabilityRelayTarget, + }, + }, + }, + }, nil, nil, nil, nil, nil), + self: selfNodeNodeAttrOnlyTCP443.View(), + peers: []tailcfg.NodeView{peerNodeCandidateRelay.View()}, + wantRelayServers: make(set.Set[candidatePeerRelay]), + wantRelayClientEnabled: false, }, { name: "self candidate relay server", @@ -3662,7 +3707,7 @@ func TestConn_updateRelayServersSet(t *testing.T) { }, }, nil, nil, nil, nil, nil), self: selfNode.View(), - peers: views.SliceOf([]tailcfg.NodeView{selfNode.View()}), + peers: []tailcfg.NodeView{selfNode.View()}, wantRelayServers: set.SetOf([]candidatePeerRelay{ { nodeKey: selfNode.Key, @@ -3670,6 +3715,7 @@ func TestConn_updateRelayServersSet(t *testing.T) { derpHomeRegionID: 2, }, }), + wantRelayClientEnabled: true, }, { name: "no candidate relay server", @@ -3684,21 +3730,34 @@ func TestConn_updateRelayServersSet(t *testing.T) { }, }, }, nil, nil, nil, nil, nil), - self: selfNode.View(), - peers: views.SliceOf([]tailcfg.NodeView{peerNodeNotCandidateRelayCapVer.View()}), - wantRelayServers: make(set.Set[candidatePeerRelay]), + self: selfNode.View(), + peers: []tailcfg.NodeView{peerNodeNotCandidateRelayCapVer.View()}, + wantRelayServers: make(set.Set[candidatePeerRelay]), + wantRelayClientEnabled: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c := &Conn{} - c.updateRelayServersSet(tt.filt, tt.self, tt.peers) + c := newConn(t.Logf) + c.filt = tt.filt + if len(tt.wantRelayServers) == 0 { + // So we can verify it gets flipped back. + c.hasPeerRelayServers.Store(true) + } + + c.onNodeViewsUpdate(NodeViewsUpdate{ + SelfNode: tt.self, + Peers: tt.peers, + }) got := c.relayManager.getServers() if !got.Equal(tt.wantRelayServers) { t.Fatalf("got: %v != want: %v", got, tt.wantRelayServers) } if len(tt.wantRelayServers) > 0 != c.hasPeerRelayServers.Load() { - t.Fatalf("c.hasPeerRelayServers: %v != wantRelayServers: %v", c.hasPeerRelayServers.Load(), tt.wantRelayServers) + t.Fatalf("c.hasPeerRelayServers: %v != len(tt.wantRelayServers) > 0: %v", c.hasPeerRelayServers.Load(), len(tt.wantRelayServers) > 0) + } + if c.relayClientEnabled != tt.wantRelayClientEnabled { + t.Fatalf("c.relayClientEnabled: %v != wantRelayClientEnabled: %v", c.relayClientEnabled, tt.wantRelayClientEnabled) } }) }