diff --git a/wgengine/magicsock/endpoint.go b/wgengine/magicsock/endpoint.go index f88dab29d..e834c277c 100644 --- a/wgengine/magicsock/endpoint.go +++ b/wgengine/magicsock/endpoint.go @@ -1871,7 +1871,7 @@ func (de *endpoint) resetLocked() { } } de.probeUDPLifetime.resetCycleEndpointLocked() - de.c.relayManager.cancelOutstandingWork(de) + de.c.relayManager.stopWork(de) } func (de *endpoint) numStopAndReset() int64 { diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index cf3ef2352..05f4cf56d 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -1960,7 +1960,7 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc ke c.discoShort, epDisco.short, via.ServerDisco.ShortString(), ep.publicKey.ShortString(), derpStr(src.String()), len(via.AddrPorts)) - c.relayManager.handleCallMeMaybeVia(via) + c.relayManager.handleCallMeMaybeVia(ep, via) } else { c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got call-me-maybe, %d endpoints", c.discoShort, epDisco.short, diff --git a/wgengine/magicsock/relaymanager.go b/wgengine/magicsock/relaymanager.go index b1732ff41..a66d80e35 100644 --- a/wgengine/magicsock/relaymanager.go +++ b/wgengine/magicsock/relaymanager.go @@ -28,21 +28,23 @@ type relayManager struct { // =================================================================== // The following fields are owned by a single goroutine, runLoop(). - serversByAddrPort set.Set[netip.AddrPort] - allocWorkByEndpoint map[*endpoint]*relayEndpointAllocWork + serversByAddrPort map[netip.AddrPort]key.DiscoPublic + serversByDisco map[key.DiscoPublic]netip.AddrPort + allocWorkByEndpoint map[*endpoint]*relayEndpointAllocWork + handshakeWorkByEndpointByServerDisco map[*endpoint]map[key.DiscoPublic]*relayHandshakeWork // =================================================================== // The following chan fields serve event inputs to a single goroutine, // runLoop(). allocateHandshakeCh chan *endpoint allocateWorkDoneCh chan relayEndpointAllocWorkDoneEvent + handshakeWorkDoneCh chan relayEndpointHandshakeWorkDoneEvent cancelWorkCh chan *endpoint newServerEndpointCh chan newRelayServerEndpointEvent rxChallengeCh chan relayHandshakeChallengeEvent - rxCallMeMaybeViaCh chan *disco.CallMeMaybeVia discoInfoMu sync.Mutex // guards the following field - discoInfoByServerDisco map[key.DiscoPublic]*discoInfo + discoInfoByServerDisco map[key.DiscoPublic]*relayHandshakeDiscoInfo // runLoopStoppedCh is written to by runLoop() upon return, enabling event // writers to restart it when they are blocked (see @@ -50,21 +52,44 @@ type relayManager struct { runLoopStoppedCh chan struct{} } -type newRelayServerEndpointEvent struct { - ep *endpoint - se udprelay.ServerEndpoint +// relayHandshakeWork serves to track in-progress relay handshake work for a +// [udprelay.ServerEndpoint]. This structure is immutable once initialized. +type relayHandshakeWork struct { + ep *endpoint + se udprelay.ServerEndpoint + cancel context.CancelFunc + wg *sync.WaitGroup } +// newRelayServerEndpointEvent indicates a new [udprelay.ServerEndpoint] has +// become known either via allocation with a relay server, or via +// [disco.CallMeMaybeVia] reception. This structure is immutable once +// initialized. +type newRelayServerEndpointEvent struct { + ep *endpoint + se udprelay.ServerEndpoint + server netip.AddrPort // zero value if learned via [disco.CallMeMaybeVia] +} + +// relayEndpointAllocWorkDoneEvent indicates relay server endpoint allocation +// work for an [*endpoint] has completed. This structure is immutable once +// initialized. type relayEndpointAllocWorkDoneEvent struct { - ep *endpoint work *relayEndpointAllocWork } -// activeWork returns true if there is outstanding allocation or handshaking -// work, otherwise it returns false. -func (r *relayManager) activeWork() bool { - return len(r.allocWorkByEndpoint) > 0 - // TODO(jwhited): consider handshaking work +// relayEndpointHandshakeWorkDoneEvent indicates relay server endpoint handshake +// work for an [*endpoint] has completed. This structure is immutable once +// initialized. +type relayEndpointHandshakeWorkDoneEvent struct { + work *relayHandshakeWork + answerSentTo netip.AddrPort // zero value if handshake did not progress to answer transmission +} + +// activeWorkRunLoop returns true if there is outstanding allocation or +// handshaking work, otherwise it returns false. +func (r *relayManager) activeWorkRunLoop() bool { + return len(r.allocWorkByEndpoint) > 0 || len(r.handshakeWorkByEndpointByServerDisco) > 0 } // runLoop is a form of event loop. It ensures exclusive access to most of @@ -77,43 +102,41 @@ func (r *relayManager) runLoop() { for { select { case ep := <-r.allocateHandshakeCh: - r.cancelAndClearWork(ep) - r.allocateAllServersForEndpoint(ep) - if !r.activeWork() { + r.stopWorkRunLoop(ep, stopHandshakeWorkOnlyKnownServers) + r.allocateAllServersRunLoop(ep) + if !r.activeWorkRunLoop() { return } - case msg := <-r.allocateWorkDoneCh: - work, ok := r.allocWorkByEndpoint[msg.ep] - if ok && work == msg.work { + case done := <-r.allocateWorkDoneCh: + work, ok := r.allocWorkByEndpoint[done.work.ep] + if ok && work == done.work { // Verify the work in the map is the same as the one that we're // cleaning up. New events on r.allocateHandshakeCh can // overwrite pre-existing keys. - delete(r.allocWorkByEndpoint, msg.ep) + delete(r.allocWorkByEndpoint, done.work.ep) } - if !r.activeWork() { + if !r.activeWorkRunLoop() { return } case ep := <-r.cancelWorkCh: - r.cancelAndClearWork(ep) - if !r.activeWork() { + r.stopWorkRunLoop(ep, stopHandshakeWorkAllServers) + if !r.activeWorkRunLoop() { return } - case newEndpoint := <-r.newServerEndpointCh: - _ = newEndpoint - // TODO(jwhited): implement - if !r.activeWork() { + case newServerEndpoint := <-r.newServerEndpointCh: + r.handleNewServerEndpointRunLoop(newServerEndpoint) + if !r.activeWorkRunLoop() { + return + } + case done := <-r.handshakeWorkDoneCh: + r.handleHandshakeWorkDoneRunLoop(done) + if !r.activeWorkRunLoop() { return } case challenge := <-r.rxChallengeCh: _ = challenge // TODO(jwhited): implement - if !r.activeWork() { - return - } - case via := <-r.rxCallMeMaybeViaCh: - _ = via - // TODO(jwhited): implement - if !r.activeWork() { + if !r.activeWorkRunLoop() { return } } @@ -142,30 +165,92 @@ type relayEndpointAllocWork struct { // init initializes [relayManager] if it is not already initialized. func (r *relayManager) init() { r.initOnce.Do(func() { - r.discoInfoByServerDisco = make(map[key.DiscoPublic]*discoInfo) + r.discoInfoByServerDisco = make(map[key.DiscoPublic]*relayHandshakeDiscoInfo) + r.serversByDisco = make(map[key.DiscoPublic]netip.AddrPort) + r.serversByAddrPort = make(map[netip.AddrPort]key.DiscoPublic) r.allocWorkByEndpoint = make(map[*endpoint]*relayEndpointAllocWork) + r.handshakeWorkByEndpointByServerDisco = make(map[*endpoint]map[key.DiscoPublic]*relayHandshakeWork) r.allocateHandshakeCh = make(chan *endpoint) r.allocateWorkDoneCh = make(chan relayEndpointAllocWorkDoneEvent) + r.handshakeWorkDoneCh = make(chan relayEndpointHandshakeWorkDoneEvent) r.cancelWorkCh = make(chan *endpoint) r.newServerEndpointCh = make(chan newRelayServerEndpointEvent) r.rxChallengeCh = make(chan relayHandshakeChallengeEvent) - r.rxCallMeMaybeViaCh = make(chan *disco.CallMeMaybeVia) r.runLoopStoppedCh = make(chan struct{}, 1) go r.runLoop() }) } +// relayHandshakeDiscoInfo serves to cache a [*discoInfo] for outstanding +// [*relayHandshakeWork] against a given relay server. +type relayHandshakeDiscoInfo struct { + work set.Set[*relayHandshakeWork] // guarded by relayManager.discoInfoMu + di *discoInfo // immutable once initialized +} + +// ensureDiscoInfoFor ensures a [*discoInfo] will be returned by discoInfo() for +// the server disco key associated with 'work'. Callers must also call +// derefDiscoInfoFor() when 'work' is complete. +func (r *relayManager) ensureDiscoInfoFor(work *relayHandshakeWork) { + r.discoInfoMu.Lock() + defer r.discoInfoMu.Unlock() + di, ok := r.discoInfoByServerDisco[work.se.ServerDisco] + if !ok { + di = &relayHandshakeDiscoInfo{} + di.work.Make() + r.discoInfoByServerDisco[work.se.ServerDisco] = di + } + di.work.Add(work) + if di.di == nil { + di.di = &discoInfo{ + discoKey: work.se.ServerDisco, + discoShort: work.se.ServerDisco.ShortString(), + sharedKey: work.ep.c.discoPrivate.Shared(work.se.ServerDisco), + } + } +} + +// derefDiscoInfoFor decrements the reference count of the [*discoInfo] +// associated with 'work'. +func (r *relayManager) derefDiscoInfoFor(work *relayHandshakeWork) { + r.discoInfoMu.Lock() + defer r.discoInfoMu.Unlock() + di, ok := r.discoInfoByServerDisco[work.se.ServerDisco] + if !ok { + // TODO(jwhited): unexpected + return + } + di.work.Delete(work) + if di.work.Len() == 0 { + delete(r.discoInfoByServerDisco, work.se.ServerDisco) + } +} + // discoInfo returns a [*discoInfo] for 'serverDisco' if there is an // active/ongoing handshake with it, otherwise it returns nil, false. func (r *relayManager) discoInfo(serverDisco key.DiscoPublic) (_ *discoInfo, ok bool) { r.discoInfoMu.Lock() defer r.discoInfoMu.Unlock() di, ok := r.discoInfoByServerDisco[serverDisco] - return di, ok + if ok { + return di.di, ok + } + return nil, false } -func (r *relayManager) handleCallMeMaybeVia(dm *disco.CallMeMaybeVia) { - relayManagerInputEvent(r, nil, &r.rxCallMeMaybeViaCh, dm) +func (r *relayManager) handleCallMeMaybeVia(ep *endpoint, dm *disco.CallMeMaybeVia) { + se := udprelay.ServerEndpoint{ + ServerDisco: dm.ServerDisco, + LamportID: dm.LamportID, + AddrPorts: dm.AddrPorts, + VNI: dm.VNI, + } + se.BindLifetime.Duration = dm.BindLifetime + se.SteadyStateLifetime.Duration = dm.SteadyStateLifetime + relayManagerInputEvent(r, nil, &r.newServerEndpointCh, newRelayServerEndpointEvent{ + ep: ep, + se: se, + }) } func (r *relayManager) handleBindUDPRelayEndpointChallenge(dm *disco.BindUDPRelayEndpointChallenge, di *discoInfo, src netip.AddrPort, vni uint32) { @@ -206,24 +291,142 @@ func (r *relayManager) allocateAndHandshakeAllServers(ep *endpoint) { relayManagerInputEvent(r, nil, &r.allocateHandshakeCh, ep) } -// cancelOutstandingWork cancels all outstanding allocation & handshaking work -// for 'ep'. -func (r *relayManager) cancelOutstandingWork(ep *endpoint) { +// stopWork stops all outstanding allocation & handshaking work for 'ep'. +func (r *relayManager) stopWork(ep *endpoint) { relayManagerInputEvent(r, nil, &r.cancelWorkCh, ep) } -// cancelAndClearWork cancels & clears any outstanding work for 'ep'. -func (r *relayManager) cancelAndClearWork(ep *endpoint) { +// stopHandshakeWorkFilter represents filters for handshake work cancellation +type stopHandshakeWorkFilter bool + +const ( + stopHandshakeWorkAllServers stopHandshakeWorkFilter = false + stopHandshakeWorkOnlyKnownServers = true +) + +// stopWorkRunLoop cancels & clears outstanding allocation and handshaking +// work for 'ep'. Handshake work cancellation is subject to the filter supplied +// in 'f'. +func (r *relayManager) stopWorkRunLoop(ep *endpoint, f stopHandshakeWorkFilter) { allocWork, ok := r.allocWorkByEndpoint[ep] if ok { allocWork.cancel() allocWork.wg.Wait() delete(r.allocWorkByEndpoint, ep) } - // TODO(jwhited): cancel & clear handshake work + byServerDisco, ok := r.handshakeWorkByEndpointByServerDisco[ep] + if ok { + for disco, handshakeWork := range byServerDisco { + _, knownServer := r.serversByDisco[disco] + if knownServer || f == stopHandshakeWorkAllServers { + handshakeWork.cancel() + handshakeWork.wg.Wait() + delete(byServerDisco, disco) + } + } + if len(byServerDisco) == 0 { + delete(r.handshakeWorkByEndpointByServerDisco, ep) + } + } } -func (r *relayManager) allocateAllServersForEndpoint(ep *endpoint) { +func (r *relayManager) handleHandshakeWorkDoneRunLoop(done relayEndpointHandshakeWorkDoneEvent) { + byServerDisco, ok := r.handshakeWorkByEndpointByServerDisco[done.work.ep] + if !ok { + return + } + work, ok := byServerDisco[done.work.se.ServerDisco] + if !ok || work != done.work { + return + } + delete(byServerDisco, done.work.se.ServerDisco) + if len(byServerDisco) == 0 { + delete(r.handshakeWorkByEndpointByServerDisco, done.work.ep) + } + if !done.answerSentTo.IsValid() { + // The handshake timed out. + return + } + // We received a challenge from and transmitted an answer towards the relay + // server. + // TODO(jwhited): Make the associated [*endpoint] aware of this + // [tailscale.com/net/udprelay.ServerEndpoint]. +} + +func (r *relayManager) handleNewServerEndpointRunLoop(newServerEndpoint newRelayServerEndpointEvent) { + if newServerEndpoint.server.IsValid() { + serverDisco, ok := r.serversByAddrPort[newServerEndpoint.server] + if !ok { + // Allocation raced with an update to our known servers set. This + // server is no longer known. Return early. + return + } + if serverDisco.Compare(newServerEndpoint.se.ServerDisco) != 0 { + // The server's disco key has either changed, or simply become + // known for the first time. In the former case we end up detaching + // any in-progress handshake work from a "known" relay server. + // Practically speaking we expect the detached work to fail + // if the server key did in fact change (server restart) while we + // were attempting to handshake with it. It is possible, though + // unlikely, for a server addr:port to effectively move between + // nodes. Either way, there is no harm in detaching existing work, + // and we explicitly let that happen for the rare case the detached + // handshake would complete and remain functional. + delete(r.serversByDisco, serverDisco) + delete(r.serversByAddrPort, newServerEndpoint.server) + r.serversByDisco[serverDisco] = newServerEndpoint.server + r.serversByAddrPort[newServerEndpoint.server] = serverDisco + } + } + + byServerDisco, ok := r.handshakeWorkByEndpointByServerDisco[newServerEndpoint.ep] + if ok { + work, ok := byServerDisco[newServerEndpoint.se.ServerDisco] + if ok { + if newServerEndpoint.se.LamportID <= work.se.LamportID { + // The "new" server endpoint is outdated or duplicate in + // consideration against existing handshake work. Return early. + return + } + // Cancel existing handshake that has a lower lamport ID. + work.cancel() + work.wg.Wait() + } + } else { + byServerDisco = make(map[key.DiscoPublic]*relayHandshakeWork) + r.handshakeWorkByEndpointByServerDisco[newServerEndpoint.ep] = byServerDisco + } + + // We're ready to start a new handshake. + ctx, cancel := context.WithCancel(context.Background()) + wg := &sync.WaitGroup{} + work := &relayHandshakeWork{ + ep: newServerEndpoint.ep, + se: newServerEndpoint.se, + cancel: cancel, + wg: wg, + } + byServerDisco[newServerEndpoint.se.ServerDisco] = work + + wg.Add(1) + go r.handshakeServerEndpoint(ctx, work) +} + +func (r *relayManager) handshakeServerEndpoint(ctx context.Context, work *relayHandshakeWork) { + defer work.wg.Done() + + done := relayEndpointHandshakeWorkDoneEvent{work: work} + r.ensureDiscoInfoFor(work) + + defer func() { + r.derefDiscoInfoFor(work) + relayManagerInputEvent(r, ctx, &r.handshakeWorkDoneCh, done) + }() + + // TODO(jwhited): implement handshake select +} + +func (r *relayManager) allocateAllServersRunLoop(ep *endpoint) { if len(r.serversByAddrPort) == 0 { return } @@ -231,17 +434,17 @@ func (r *relayManager) allocateAllServersForEndpoint(ep *endpoint) { started := &relayEndpointAllocWork{ep: ep, cancel: cancel, wg: &sync.WaitGroup{}} for k := range r.serversByAddrPort { started.wg.Add(1) - go r.allocateEndpoint(ctx, started.wg, k, ep) + go r.allocateSingleServer(ctx, started.wg, k, ep) } r.allocWorkByEndpoint[ep] = started go func() { started.wg.Wait() started.cancel() - relayManagerInputEvent(r, ctx, &r.allocateWorkDoneCh, relayEndpointAllocWorkDoneEvent{ep: ep, work: started}) + relayManagerInputEvent(r, ctx, &r.allocateWorkDoneCh, relayEndpointAllocWorkDoneEvent{work: started}) }() } -func (r *relayManager) allocateEndpoint(ctx context.Context, wg *sync.WaitGroup, server netip.AddrPort, ep *endpoint) { +func (r *relayManager) allocateSingleServer(ctx context.Context, wg *sync.WaitGroup, server netip.AddrPort, ep *endpoint) { // TODO(jwhited): introduce client metrics counters for notable failures defer wg.Done() var b bytes.Buffer diff --git a/wgengine/magicsock/relaymanager_test.go b/wgengine/magicsock/relaymanager_test.go index 579dceb53..3b75db9f6 100644 --- a/wgengine/magicsock/relaymanager_test.go +++ b/wgengine/magicsock/relaymanager_test.go @@ -8,6 +8,7 @@ import ( "testing" "tailscale.com/disco" + "tailscale.com/types/key" ) func TestRelayManagerInitAndIdle(t *testing.T) { @@ -16,11 +17,11 @@ func TestRelayManagerInitAndIdle(t *testing.T) { <-rm.runLoopStoppedCh rm = relayManager{} - rm.cancelOutstandingWork(&endpoint{}) + rm.stopWork(&endpoint{}) <-rm.runLoopStoppedCh rm = relayManager{} - rm.handleCallMeMaybeVia(&disco.CallMeMaybeVia{}) + rm.handleCallMeMaybeVia(&endpoint{c: &Conn{discoPrivate: key.NewDisco()}}, &disco.CallMeMaybeVia{ServerDisco: key.NewDisco().Public()}) <-rm.runLoopStoppedCh rm = relayManager{}