diff --git a/wgengine/magicsock/relaymanager.go b/wgengine/magicsock/relaymanager.go index fd3f19dfb..2b636dc57 100644 --- a/wgengine/magicsock/relaymanager.go +++ b/wgengine/magicsock/relaymanager.go @@ -112,12 +112,21 @@ type relayEndpointHandshakeWorkDoneEvent struct { latency time.Duration // only relevant if pongReceivedFrom.IsValid() } -// activeWorkRunLoop returns true if there is outstanding allocation or -// handshaking work, otherwise it returns false. -func (r *relayManager) activeWorkRunLoop() bool { +// hasActiveWorkRunLoop returns true if there is outstanding allocation or +// handshaking work for any endpoint, otherwise it returns false. +func (r *relayManager) hasActiveWorkRunLoop() bool { return len(r.allocWorkByEndpoint) > 0 || len(r.handshakeWorkByEndpointByServerDisco) > 0 } +// hasActiveWorkForEndpointRunLoop returns true if there is outstanding +// allocation or handshaking work for the provided endpoint, otherwise it +// returns false. +func (r *relayManager) hasActiveWorkForEndpointRunLoop(ep *endpoint) bool { + _, handshakeWork := r.handshakeWorkByEndpointByServerDisco[ep] + _, allocWork := r.allocWorkByEndpoint[ep] + return handshakeWork || allocWork +} + // runLoop is a form of event loop. It ensures exclusive access to most of // [relayManager] state. func (r *relayManager) runLoop() { @@ -128,9 +137,10 @@ func (r *relayManager) runLoop() { for { select { case ep := <-r.allocateHandshakeCh: - r.stopWorkRunLoop(ep, stopHandshakeWorkOnlyKnownServers) - r.allocateAllServersRunLoop(ep) - if !r.activeWorkRunLoop() { + if !r.hasActiveWorkForEndpointRunLoop(ep) { + r.allocateAllServersRunLoop(ep) + } + if !r.hasActiveWorkRunLoop() { return } case done := <-r.allocateWorkDoneCh: @@ -141,27 +151,27 @@ func (r *relayManager) runLoop() { // overwrite pre-existing keys. delete(r.allocWorkByEndpoint, done.work.ep) } - if !r.activeWorkRunLoop() { + if !r.hasActiveWorkRunLoop() { return } case ep := <-r.cancelWorkCh: - r.stopWorkRunLoop(ep, stopHandshakeWorkAllServers) - if !r.activeWorkRunLoop() { + r.stopWorkRunLoop(ep) + if !r.hasActiveWorkRunLoop() { return } case newServerEndpoint := <-r.newServerEndpointCh: r.handleNewServerEndpointRunLoop(newServerEndpoint) - if !r.activeWorkRunLoop() { + if !r.hasActiveWorkRunLoop() { return } case done := <-r.handshakeWorkDoneCh: r.handleHandshakeWorkDoneRunLoop(done) - if !r.activeWorkRunLoop() { + if !r.hasActiveWorkRunLoop() { return } case discoMsgEvent := <-r.rxHandshakeDiscoMsgCh: r.handleRxHandshakeDiscoMsgRunLoop(discoMsgEvent) - if !r.activeWorkRunLoop() { + if !r.hasActiveWorkRunLoop() { return } } @@ -317,8 +327,8 @@ func relayManagerInputEvent[T any](r *relayManager, ctx context.Context, eventCh } // allocateAndHandshakeAllServers kicks off allocation and handshaking of relay -// endpoints for 'ep' on all known relay servers, canceling any existing -// in-progress work. +// endpoints for 'ep' on all known relay servers if there is no outstanding +// work. func (r *relayManager) allocateAndHandshakeAllServers(ep *endpoint) { relayManagerInputEvent(r, nil, &r.allocateHandshakeCh, ep) } @@ -328,18 +338,9 @@ func (r *relayManager) stopWork(ep *endpoint) { relayManagerInputEvent(r, nil, &r.cancelWorkCh, ep) } -// stopHandshakeWorkFilter represents filters for handshake work cancellation -type stopHandshakeWorkFilter bool - -const ( - stopHandshakeWorkAllServers stopHandshakeWorkFilter = false - stopHandshakeWorkOnlyKnownServers = true -) - // stopWorkRunLoop cancels & clears outstanding allocation and handshaking -// work for 'ep'. Handshake work cancellation is subject to the filter supplied -// in 'f'. -func (r *relayManager) stopWorkRunLoop(ep *endpoint, f stopHandshakeWorkFilter) { +// work for 'ep'. +func (r *relayManager) stopWorkRunLoop(ep *endpoint) { allocWork, ok := r.allocWorkByEndpoint[ep] if ok { allocWork.cancel() @@ -348,13 +349,10 @@ func (r *relayManager) stopWorkRunLoop(ep *endpoint, f stopHandshakeWorkFilter) } byServerDisco, ok := r.handshakeWorkByEndpointByServerDisco[ep] if ok { - for disco, handshakeWork := range byServerDisco { - _, knownServer := r.serversByDisco[disco] - if knownServer || f == stopHandshakeWorkAllServers { - handshakeWork.cancel() - done := <-handshakeWork.doneCh - r.handleHandshakeWorkDoneRunLoop(done) - } + for _, handshakeWork := range byServerDisco { + handshakeWork.cancel() + done := <-handshakeWork.doneCh + r.handleHandshakeWorkDoneRunLoop(done) } } }