tailscale/wgengine/magicsock/relaymanager.go

// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause

package magicsock

import (
	"bytes"
	"context"
	"encoding/json"
	"io"
	"net/http"
	"net/netip"
	"sync"
	"time"

	"tailscale.com/disco"
	"tailscale.com/net/udprelay"
	"tailscale.com/types/key"
	"tailscale.com/util/httpm"
)

// relayManager manages allocation and handshaking of
// [tailscale.com/net/udprelay.Server] endpoints. The zero value is ready for
// use.
type relayManager struct {
	mu sync.Mutex // guards the following fields

	// discoInfoByServerDisco holds a [*discoInfo] for each relay server with
	// an active/ongoing handshake, keyed by the server's disco key.
	discoInfoByServerDisco map[key.DiscoPublic]*discoInfo

	// serversByAddrPort value is the disco key of the relay server, which is
	// discovered at relay endpoint allocation time. The map value is the zero
	// value (key.DiscoPublic.IsZero()) if no endpoints have been successfully
	// allocated on the server yet.
	serversByAddrPort map[netip.AddrPort]key.DiscoPublic

	// relaySetupWorkByEndpoint tracks in-progress allocation and handshaking
	// work, keyed by the [*endpoint] the work is for.
	relaySetupWorkByEndpoint map[*endpoint]*relaySetupWork
}
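
// A minimal usage sketch (hypothetical; 'ep' stands in for an existing
// [*endpoint] and is not defined here). The zero value needs no constructor,
// so an owning struct can embed a relayManager and call its methods directly;
// they lazily initialize the internal maps via initLocked as needed:
//
//	var rm relayManager
//	rm.allocateAndHandshakeAllServers(ep) // begins (or restarts) work for ep
//	rm.cancelOutstandingWork(ep)          // cancels it and waits for goroutines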

// relaySetupWork serves to track in-progress relay endpoint allocation and
// handshaking work for an [*endpoint]. This structure is immutable once
// initialized.
type relaySetupWork struct {
	// ep is the [*endpoint] associated with the work
	ep *endpoint
	// cancel() will signal all associated goroutines to return
	cancel context.CancelFunc
	// wg.Wait() will return once all associated goroutines have returned
	wg *sync.WaitGroup
}
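
// The cancel/wg pair implements a stop-and-drain pattern: with
// relayManager.mu held, callers invoke cancel() to signal the goroutines
// spawned for this work, then wg.Wait() to block until all of them have
// returned, so goroutines from a superseded setup run never outlive their
// replacement. See cancelOutstandingWork and allocateAndHandshakeAllServers.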

// initLocked lazily initializes r's maps if they have not been already.
// r.mu must be held.
func (r *relayManager) initLocked() {
	if r.discoInfoByServerDisco != nil {
		return
	}
	r.discoInfoByServerDisco = make(map[key.DiscoPublic]*discoInfo)
	r.serversByAddrPort = make(map[netip.AddrPort]key.DiscoPublic)
	r.relaySetupWorkByEndpoint = make(map[*endpoint]*relaySetupWork)
}

// discoInfo returns a [*discoInfo] for 'serverDisco' if there is an
// active/ongoing handshake with it, otherwise it returns nil, false.
func (r *relayManager) discoInfo(serverDisco key.DiscoPublic) (_ *discoInfo, ok bool) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.initLocked()
	di, ok := r.discoInfoByServerDisco[serverDisco]
	return di, ok
}

// handleCallMeMaybeVia handles a [disco.CallMeMaybeVia] message received from
// a peer.
func (r *relayManager) handleCallMeMaybeVia(dm *disco.CallMeMaybeVia) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.initLocked()
	// TODO(jwhited): implement
}

// handleBindUDPRelayEndpointChallenge handles a
// [disco.BindUDPRelayEndpointChallenge] message received from the relay
// server associated with 'di', from source address 'src', over the virtual
// network identified by 'vni'.
func (r *relayManager) handleBindUDPRelayEndpointChallenge(dm *disco.BindUDPRelayEndpointChallenge, di *discoInfo, src netip.AddrPort, vni uint32) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.initLocked()
	// TODO(jwhited): implement
}

// cancelOutstandingWork cancels any in-progress work for 'ep'.
func (r *relayManager) cancelOutstandingWork(ep *endpoint) {
	r.mu.Lock()
	defer r.mu.Unlock()
	// Lookups and deletes on nil maps are safe, so initLocked is not needed
	// here.
	existing, ok := r.relaySetupWorkByEndpoint[ep]
	if ok {
		existing.cancel()
		existing.wg.Wait()
		delete(r.relaySetupWorkByEndpoint, ep)
	}
}

// allocateAndHandshakeAllServers kicks off allocation and handshaking of
// relay endpoints for 'ep' on all known relay servers, canceling any existing
// in-progress work.
func (r *relayManager) allocateAndHandshakeAllServers(ep *endpoint) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.initLocked()
	existing, ok := r.relaySetupWorkByEndpoint[ep]
	if ok {
		existing.cancel()
		existing.wg.Wait()
		delete(r.relaySetupWorkByEndpoint, ep)
	}
	if len(r.serversByAddrPort) == 0 {
		return
	}
	ctx, cancel := context.WithCancel(context.Background())
	started := &relaySetupWork{ep: ep, cancel: cancel, wg: &sync.WaitGroup{}}
	for k := range r.serversByAddrPort {
		started.wg.Add(1)
		go r.allocateAndHandshakeForServer(ctx, started.wg, k, ep)
	}
	r.relaySetupWorkByEndpoint[ep] = started
	go func() {
		started.wg.Wait()
		started.cancel()
		r.mu.Lock()
		defer r.mu.Unlock()
		maybeCleanup, ok := r.relaySetupWorkByEndpoint[ep]
		if ok && maybeCleanup == started {
			// A subsequent call to allocateAndHandshakeAllServers may have
			// raced to delete the associated key/value, so ensure the work we
			// are cleaning up from the map is the same as the one we were
			// waiting to finish.
			delete(r.relaySetupWorkByEndpoint, ep)
		}
	}()
}
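
// The cleanup goroutine above serves two purposes: started.cancel() releases
// the resources held by the [context.Context] once every worker has returned
// (cancel funcs are idempotent, so a racing cancelOutstandingWork is
// harmless), and the map entry is deleted so relaySetupWorkByEndpoint only
// ever holds in-progress work.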

// handleNewServerEndpoint handles a [udprelay.ServerEndpoint] newly allocated
// on 'server' as part of the setup work that owns 'wg'.
func (r *relayManager) handleNewServerEndpoint(ctx context.Context, wg *sync.WaitGroup, server netip.AddrPort, se udprelay.ServerEndpoint) {
	// TODO(jwhited): implement
}

// allocateAndHandshakeForServer allocates a relay endpoint for 'ep' on
// 'server' over HTTP and, on success, passes the result to
// handleNewServerEndpoint.
func (r *relayManager) allocateAndHandshakeForServer(ctx context.Context, wg *sync.WaitGroup, server netip.AddrPort, ep *endpoint) {
	// TODO(jwhited): introduce client metrics counters for notable failures
	defer wg.Done()
	var b bytes.Buffer
	remoteDisco := ep.disco.Load()
	if remoteDisco == nil {
		// We cannot request an endpoint without knowing the remote peer's
		// disco key.
		return
	}
	type allocateRelayEndpointReq struct {
		DiscoKeys []key.DiscoPublic
	}
	a := &allocateRelayEndpointReq{
		DiscoKeys: []key.DiscoPublic{ep.c.discoPublic, remoteDisco.key},
	}
	err := json.NewEncoder(&b).Encode(a)
	if err != nil {
		return
	}
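	// For illustration (derived from allocateRelayEndpointReq above, not a
	// documented wire contract), the encoded request body looks like:
	//
	//	{"DiscoKeys":["discokey:<hex>","discokey:<hex>"]}
	//
	// since [key.DiscoPublic] marshals to text as "discokey:" followed by the
	// hex-encoded public key.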
	const reqTimeout = time.Second * 10
	reqCtx, cancel := context.WithTimeout(ctx, reqTimeout)
	defer cancel()
	req, err := http.NewRequestWithContext(reqCtx, httpm.POST, "http://"+server.String()+"/relay/endpoint", &b)
	if err != nil {
		return
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return
	}
	var se udprelay.ServerEndpoint
	// Bound the read; a well-behaved server returns a small JSON body
	// describing the allocated endpoint.
	err = json.NewDecoder(io.LimitReader(resp.Body, 4096)).Decode(&se)
	if err != nil {
		return
	}
	r.handleNewServerEndpoint(ctx, wg, server, se)
}