diff --git a/wgengine/magicsock/relaymanager.go b/wgengine/magicsock/relaymanager.go index 4680832d9..a9dca70ae 100644 --- a/wgengine/magicsock/relaymanager.go +++ b/wgengine/magicsock/relaymanager.go @@ -758,7 +758,10 @@ func (r *relayManager) handleNewServerEndpointRunLoop(newServerEndpoint newRelay ctx: ctx, cancel: cancel, } - if byServerDisco == nil { + // We must look up byServerDisco again. The previous value may have been + // deleted from the outer map when cleaning up duplicate work. + byServerDisco, ok = r.handshakeWorkByServerDiscoByEndpoint[newServerEndpoint.wlb.ep] + if !ok { byServerDisco = make(map[key.DiscoPublic]*relayHandshakeWork) r.handshakeWorkByServerDiscoByEndpoint[newServerEndpoint.wlb.ep] = byServerDisco } diff --git a/wgengine/magicsock/relaymanager_test.go b/wgengine/magicsock/relaymanager_test.go index e4891f567..6ae21b8fb 100644 --- a/wgengine/magicsock/relaymanager_test.go +++ b/wgengine/magicsock/relaymanager_test.go @@ -7,6 +7,7 @@ import ( "testing" "tailscale.com/disco" + udprelay "tailscale.com/net/udprelay/endpoint" "tailscale.com/types/key" "tailscale.com/util/set" ) @@ -78,3 +79,41 @@ func TestRelayManagerGetServers(t *testing.T) { t.Errorf("got %v != want %v", got, servers) } } + +// Test for http://go/corp/32978 +func TestRelayManager_handleNewServerEndpointRunLoop(t *testing.T) { + rm := relayManager{} + rm.init() + <-rm.runLoopStoppedCh // prevent runLoop() from starting, we will inject/handle events in the test + ep := &endpoint{} + conn := newConn(t.Logf) + ep.c = conn + serverDisco := key.NewDisco().Public() + rm.handleNewServerEndpointRunLoop(newRelayServerEndpointEvent{ + wlb: endpointWithLastBest{ + ep: ep, + }, + se: udprelay.ServerEndpoint{ + ServerDisco: serverDisco, + LamportID: 1, + VNI: 1, + }, + }) + rm.handleNewServerEndpointRunLoop(newRelayServerEndpointEvent{ + wlb: endpointWithLastBest{ + ep: ep, + }, + se: udprelay.ServerEndpoint{ + ServerDisco: serverDisco, + LamportID: 2, + VNI: 2, + }, + }) + rm.stopWorkRunLoop(ep) + if len(rm.handshakeWorkByServerDiscoByEndpoint) != 0 || + len(rm.handshakeWorkByServerDiscoVNI) != 0 || + len(rm.handshakeWorkAwaitingPong) != 0 || + len(rm.addrPortVNIToHandshakeWork) != 0 { + t.Fatal("stranded relayHandshakeWork state") + } +}