tailscale/net/udprelay/server.go

533 lines
16 KiB
Go
Raw Normal View History

// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package udprelay contains constructs for relaying Disco and WireGuard packets
// between Tailscale clients over UDP. This package is currently considered
// experimental.
package udprelay
import (
"bytes"
"crypto/rand"
"errors"
"fmt"
"net"
"net/netip"
"slices"
"strconv"
"sync"
"time"
"go4.org/mem"
"tailscale.com/disco"
"tailscale.com/net/packet"
"tailscale.com/types/key"
)
// Lifetimes that NewServer installs on every Server it constructs.
const (
	// defaultBindLifetime is how long an allocated endpoint may exist before
	// both clients complete the 3-way bind handshake. The value is somewhat
	// arbitrary. It budgets for three potentially high-latency paths: each
	// client's path to the Server, plus the side channel (e.g. DERP) the
	// clients use to exchange ServerEndpoint details. A conservative 10s per
	// path gives 30s in total. Erring on the long side is deliberate. A too
	// aggressive bind lifetime risks path discovery failure, whereas a
	// generous one only costs some extra Server resources (VNIs, RAM, etc.)
	// spent tracking endpoints that never bind.
	defaultBindLifetime = 30 * time.Second

	// defaultSteadyStateLifetime is how long a bound endpoint is kept alive
	// without bidirectional data flow.
	defaultSteadyStateLifetime = 5 * time.Minute
)
// Server implements an experimental UDP relay server.
type Server struct {
	// disco keypair used as part of 3-way bind handshake
	disco       key.DiscoPrivate
	discoPublic key.DiscoPublic

	// Lifetimes copied into every ServerEndpoint and used by the endpoint GC.
	bindLifetime        time.Duration
	steadyStateLifetime time.Duration

	// addrPorts contains the ip:port pairs returned as candidate server
	// endpoints in response to an allocation request.
	addrPorts []netip.AddrPort

	// uc is the socket all relayed and control packets are read from and
	// written to.
	uc *net.UDPConn

	closeOnce sync.Once      // makes Close's shutdown run only once
	wg        sync.WaitGroup // tracks packetReadLoop and endpointGCLoop
	closeCh   chan struct{}  // closed by Close to stop endpointGCLoop

	mu sync.Mutex // guards the following fields
	// closed is set by Close; once true, AllocateEndpoint returns
	// ErrServerClosed.
	closed    bool
	lamportID uint64
	vniPool   []uint32 // the pool of available VNIs
	byVNI     map[uint32]*serverEndpoint
	byDisco   map[pairOfDiscoPubKeys]*serverEndpoint
}
// pairOfDiscoPubKeys holds two key.DiscoPublic values in lexicographical
// order. Always build one with newPairOfDiscoPubKeys, so that the same two
// keys yield the same pair, and therefore the same map key, regardless of the
// order in which they were supplied.
type pairOfDiscoPubKeys [2]key.DiscoPublic

// String renders the pair as "<short key 0> <=> <short key 1>".
func (p pairOfDiscoPubKeys) String() string {
	first, second := p[0].ShortString(), p[1].ShortString()
	return first + " <=> " + second
}
// newPairOfDiscoPubKeys returns discoA and discoB as a pairOfDiscoPubKeys,
// sorted ascending by key.DiscoPublic.Compare, so that argument order does
// not affect the result.
func newPairOfDiscoPubKeys(discoA, discoB key.DiscoPublic) pairOfDiscoPubKeys {
	pair := pairOfDiscoPubKeys{discoA, discoB}
	slices.SortFunc(pair[:], key.DiscoPublic.Compare)
	return pair
}
// ServerEndpoint contains the Server's endpoint details, as returned by
// Server.AllocateEndpoint.
type ServerEndpoint struct {
	// ServerDisco is the Server's Disco public key used as part of the 3-way
	// bind handshake. Server will use the same ServerDisco for its lifetime.
	// ServerDisco value in combination with LamportID value represents a
	// unique ServerEndpoint allocation.
	ServerDisco key.DiscoPublic

	// LamportID is unique and monotonically non-decreasing across
	// ServerEndpoint allocations for the lifetime of Server. It enables clients
	// to dedup and resolve allocation event order. Clients may race to allocate
	// on the same Server, and signal ServerEndpoint details via alternative
	// channels, e.g. DERP. Additionally, Server.AllocateEndpoint() requests may
	// not result in a new allocation depending on existing server-side endpoint
	// state. In that case, the existing, not-yet-bound allocation is returned
	// again with its original LamportID. Therefore, where clients have local,
	// existing state that contains ServerDisco and LamportID values matching a
	// newly learned endpoint, these can be considered one and the same. If
	// ServerDisco is equal, but LamportID is unequal, LamportID comparison
	// determines which ServerEndpoint was allocated most recently.
	LamportID uint64

	// AddrPorts are the IP:Port candidate pairs the Server may be reachable
	// over.
	AddrPorts []netip.AddrPort

	// VNI (Virtual Network Identifier) is the Geneve header VNI the Server
	// will use for transmitted packets, and expects for received packets
	// associated with this endpoint.
	VNI uint32

	// BindLifetime is the amount of time post-allocation the Server will
	// consider the endpoint active while it has yet to be bound via 3-way bind
	// handshake from both client parties.
	BindLifetime time.Duration

	// SteadyStateLifetime is the amount of time post 3-way bind handshake from
	// both client parties the Server will consider the endpoint active lacking
	// bidirectional data flow.
	SteadyStateLifetime time.Duration
}
// serverEndpoint is the Server-internal state behind one ServerEndpoint
// allocation. Its methods do no locking of their own (they are not
// thread-safe); the Server serializes access by holding Server.mu around
// every use.
type serverEndpoint struct {
	lamportID   uint64    // Server.lamportID value assigned at allocation
	vni         uint32    // Geneve VNI identifying this endpoint
	allocatedAt time.Time // when the endpoint was allocated

	// discoPubKeys holds the key.DiscoPublic of the two served clients. Every
	// [2]-array field below is indexed the same way: slot i belongs to the
	// client whose key is discoPubKeys[i]. For example, discoSharedSecrets[0]
	// seals Disco messages sent towards discoPubKeys[0].
	discoPubKeys       pairOfDiscoPubKeys
	discoSharedSecrets [2]key.DiscoShared
	handshakeState     [2]disco.BindUDPRelayHandshakeState
	addrPorts          [2]netip.AddrPort
	lastSeen           [2]time.Time // TODO(jwhited): consider using mono.Time
	challenge          [2][disco.BindUDPRelayEndpointChallengeLen]byte
}
// handleDiscoControlMsg advances the bind handshake for the client in slot
// senderIndex (only 0 and 1 are accepted). It reacts to an already-decrypted
// Disco message received from 'from'. A sender that is already bound is
// ignored.
//
//   - BindUDPRelayEndpoint: in state Init the sender's source address is
//     recorded. In state Init, or in ChallengeSent when 'from' matches the
//     recorded address, a sealed BindUDPRelayEndpointChallenge is written back
//     to 'from' and the state becomes ChallengeSent.
//   - BindUDPRelayEndpointAnswer: in state ChallengeSent, from the recorded
//     address, carrying the expected challenge bytes, the sender becomes bound
//     (AnswerReceived) and its lastSeen is stamped.
//
// Any other message or state combination is silently dropped. Write errors
// from uw are ignored.
func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex int, discoMsg disco.Message, uw udpWriter, serverDisco key.DiscoPublic) {
	if senderIndex < 0 || senderIndex > 1 {
		return
	}
	state := e.handshakeState[senderIndex]
	if state == disco.BindUDPRelayHandshakeStateAnswerReceived {
		return // this sender is already bound
	}

	switch msg := discoMsg.(type) {
	case *disco.BindUDPRelayEndpoint:
		switch state {
		case disco.BindUDPRelayHandshakeStateInit:
			// First bind from this sender: pin its source address.
			e.addrPorts[senderIndex] = from
		case disco.BindUDPRelayHandshakeStateChallengeSent:
			if from != e.addrPorts[senderIndex] {
				// A later arriving bind from a different source, or a
				// retransmit whose source has changed: discard.
				return
			}
		default:
			return // unexpected in all other handshake states
		}

		challenge := new(disco.BindUDPRelayEndpointChallenge)
		copy(challenge.Challenge[:], e.challenge[senderIndex][:])

		hdr := packet.GeneveHeader{Control: true, VNI: e.vni, Protocol: packet.GeneveProtocolDisco}
		pkt := make([]byte, packet.GeneveFixedHeaderLength, 512)
		if err := hdr.Encode(pkt); err != nil {
			return
		}
		// Wire layout: Geneve header | Disco magic | server disco key | sealed box.
		pkt = append(pkt, disco.Magic...)
		pkt = serverDisco.AppendTo(pkt)
		pkt = append(pkt, e.discoSharedSecrets[senderIndex].Seal(challenge.AppendMarshal(nil))...)
		uw.WriteMsgUDPAddrPort(pkt, nil, from)
		e.handshakeState[senderIndex] = disco.BindUDPRelayHandshakeStateChallengeSent

	case *disco.BindUDPRelayEndpointAnswer:
		if state != disco.BindUDPRelayHandshakeStateChallengeSent {
			return // unexpected in all other handshake states
		}
		if from != e.addrPorts[senderIndex] {
			return // sender source has changed
		}
		if !bytes.Equal(msg.Answer[:], e.challenge[senderIndex][:]) {
			return // bad answer
		}
		// Sender is now bound.
		// TODO: Consider installing a fast path via netfilter or similar to
		// relay (NAT) data packets for this serverEndpoint.
		e.handshakeState[senderIndex] = disco.BindUDPRelayHandshakeStateAnswerReceived
		e.lastSeen[senderIndex] = time.Now() // bound time doubles as first "seen"
	}
}
// handleSealedDiscoControlMsg opens a sealed Disco message b received from
// 'from'. The layout is: Disco magic, the sender's raw disco public key, then
// the sealed payload. It identifies which served client sent it, decrypts it
// with that client's shared secret, parses it, and passes the result to
// handleDiscoControlMsg. Messages that are not Disco, come from an unknown
// key, fail to decrypt, or fail to parse are dropped.
func (e *serverEndpoint) handleSealedDiscoControlMsg(from netip.AddrPort, b []byte, uw udpWriter, serverDisco key.DiscoPublic) {
	raw, ok := disco.Source(b)
	if !ok {
		return // not a Disco message
	}
	sender := key.DiscoPublicFromRaw32(mem.B(raw))

	// Map the sender onto one of the two served client slots.
	senderIndex := -1
	for i, k := range e.discoPubKeys {
		if sender.Compare(k) == 0 {
			senderIndex = i
			break
		}
	}
	if senderIndex == -1 {
		return // unknown Disco public key
	}

	const sealedOffset = len(disco.Magic) + key.DiscoPublicRawLen
	plaintext, ok := e.discoSharedSecrets[senderIndex].Open(b[sealedOffset:])
	if !ok {
		return // unable to decrypt the Disco payload
	}
	msg, err := disco.Parse(plaintext)
	if err != nil {
		return // unable to parse the Disco payload
	}
	e.handleDiscoControlMsg(from, senderIndex, msg, uw, serverDisco)
}
// udpWriter is the subset of *net.UDPConn that endpoints need in order to
// send packets, so that tests can substitute a fake.
type udpWriter interface {
	WriteMsgUDPAddrPort(buf []byte, oob []byte, dst netip.AddrPort) (n, oobn int, err error)
}
// handlePacket processes b, a Geneve packet from 'from' whose header has
// already been decoded into gh.
//
// Control packets carry Disco bind-handshake messages. They are accepted only
// while the endpoint is not yet bound, and only when the protocol is Disco.
// Data packets are relayed unchanged, Geneve header included, only once the
// endpoint is bound and only when they come from one of the two client
// addresses. The sending side's lastSeen is refreshed. Write errors are
// ignored.
func (e *serverEndpoint) handlePacket(from netip.AddrPort, gh packet.GeneveHeader, b []byte, uw udpWriter, serverDisco key.DiscoPublic) {
	bound := e.isBound()

	if gh.Control {
		if bound || gh.Protocol != packet.GeneveProtocolDisco {
			// Already bound, or a control packet that is not Disco.
			return
		}
		e.handleSealedDiscoControlMsg(from, b[packet.GeneveFixedHeaderLength:], uw, serverDisco)
		return
	}

	if !bound {
		return // data packet before both clients have bound
	}
	var src, dst int
	switch from {
	case e.addrPorts[0]:
		src, dst = 0, 1
	case e.addrPorts[1]:
		src, dst = 1, 0
	default:
		return // unrecognized source
	}
	e.lastSeen[src] = time.Now()
	uw.WriteMsgUDPAddrPort(b, nil, e.addrPorts[dst])
}
// isExpired reports whether e should be reclaimed as of now. An endpoint that
// is not yet bound expires once more than bindLifetime has elapsed since it
// was allocated. A bound endpoint expires once either client has gone unseen
// for more than steadyStateLifetime.
func (e *serverEndpoint) isExpired(now time.Time, bindLifetime, steadyStateLifetime time.Duration) bool {
	if e.isBound() {
		idle0 := now.Sub(e.lastSeen[0])
		idle1 := now.Sub(e.lastSeen[1])
		return idle0 > steadyStateLifetime || idle1 > steadyStateLifetime
	}
	return now.Sub(e.allocatedAt) > bindLifetime
}
// isBound reports whether both clients have completed their 3-way handshake,
// i.e. every handshake slot is in state AnswerReceived.
func (e *serverEndpoint) isBound() bool {
	for _, state := range e.handshakeState {
		if state != disco.BindUDPRelayHandshakeStateAnswerReceived {
			return false
		}
	}
	return true
}
// NewServer constructs a Server listening on 0.0.0.0:'port'. IPv6 is not yet
// supported. Port may be 0, and what ultimately gets bound is returned as
// 'boundPort'. Supplied 'addrs' are joined with 'boundPort' and returned as
// ServerEndpoint.AddrPorts in response to Server.AllocateEndpoint() requests.
//
// On success the Server's packet-read and endpoint-GC goroutines are running;
// call Close to stop them. On error no socket is left open.
//
// TODO: IPv6 support
// TODO: dynamic addrs:port discovery
func NewServer(port int, addrs []netip.Addr) (s *Server, boundPort int, err error) {
	s = &Server{
		disco:               key.NewDisco(),
		bindLifetime:        defaultBindLifetime,
		steadyStateLifetime: defaultSteadyStateLifetime,
		closeCh:             make(chan struct{}),
		byDisco:             make(map[pairOfDiscoPubKeys]*serverEndpoint),
		byVNI:               make(map[uint32]*serverEndpoint),
	}
	s.discoPublic = s.disco.Public()
	// The pool holds every 24-bit VNI except 0, i.e. 1 through (1<<24)-1.
	// TODO: instead of allocating 10s of MBs for the full pool, allocate
	// smaller chunks and increase as needed
	s.vniPool = make([]uint32, 0, 1<<24-1)
	for i := 1; i < 1<<24; i++ {
		s.vniPool = append(s.vniPool, uint32(i))
	}
	boundPort, err = s.listenOn(port)
	if err != nil {
		return nil, 0, err
	}
	addrPorts := make([]netip.AddrPort, 0, len(addrs))
	for _, addr := range addrs {
		addrPort, parseErr := netip.ParseAddrPort(net.JoinHostPort(addr.String(), strconv.Itoa(boundPort)))
		if parseErr != nil {
			// listenOn already opened the socket; don't leak it.
			s.uc.Close()
			return nil, 0, parseErr
		}
		addrPorts = append(addrPorts, addrPort)
	}
	s.addrPorts = addrPorts
	s.wg.Add(2)
	go s.packetReadLoop()
	go s.endpointGCLoop()
	return s, boundPort, nil
}
// listenOn opens an IPv4 UDP socket on 0.0.0.0:port, stores it in s.uc, and
// returns the port actually bound. This matters when port is 0. On failure
// s.uc is left unset and the socket, if one was opened, is closed.
func (s *Server) listenOn(port int) (int, error) {
	uc, err := net.ListenUDP("udp4", &net.UDPAddr{Port: port})
	if err != nil {
		return 0, err
	}
	// TODO: set IP_PKTINFO sockopt
	laddr, ok := uc.LocalAddr().(*net.UDPAddr)
	if !ok {
		// Close the socket we just opened; s.uc has not been set yet.
		uc.Close()
		return 0, fmt.Errorf("unexpected local address type %T", uc.LocalAddr())
	}
	s.uc = uc
	return laddr.Port, nil
}
// Close closes the server. It marks the server closed, so AllocateEndpoint
// returns ErrServerClosed. It then closes the UDP socket, stops the
// goroutines started by NewServer, waits for them to exit, and releases all
// endpoint state. Only the first call performs the shutdown; concurrent
// callers wait for it to finish. Close always returns nil.
func (s *Server) Close() error {
	s.closeOnce.Do(func() {
		// s.mu must NOT be held across s.wg.Wait(): packetReadLoop (via
		// handlePacket) and endpointGCLoop both acquire s.mu, so waiting for
		// them while holding it would deadlock. Mark closed first so no new
		// allocations happen during shutdown.
		s.mu.Lock()
		s.closed = true
		s.mu.Unlock()

		s.uc.Close()
		close(s.closeCh)
		s.wg.Wait()

		s.mu.Lock()
		defer s.mu.Unlock()
		clear(s.byVNI)
		clear(s.byDisco)
		s.vniPool = nil
	})
	return nil
}
// endpointGCLoop wakes every bindLifetime and removes expired endpoints,
// returning their VNIs to the pool. It runs until closeCh is closed.
func (s *Server) endpointGCLoop() {
	defer s.wg.Done()

	t := time.NewTicker(s.bindLifetime)
	defer t.Stop()

	// sweep deletes every expired endpoint under s.mu.
	sweep := func() {
		now := time.Now()
		// TODO: consider performance implications of scanning all endpoints and
		// holding s.mu for the duration. Keep it simple (and slow) for now.
		s.mu.Lock()
		defer s.mu.Unlock()
		for pair, ep := range s.byDisco {
			if !ep.isExpired(now, s.bindLifetime, s.steadyStateLifetime) {
				continue
			}
			delete(s.byDisco, pair)
			delete(s.byVNI, ep.vni)
			s.vniPool = append(s.vniPool, ep.vni)
		}
	}

	for {
		select {
		case <-s.closeCh:
			return
		case <-t.C:
			sweep()
		}
	}
}
// handlePacket decodes the Geneve header of b, a packet received from 'from',
// and hands the packet to the endpoint registered for its VNI. Packets whose
// header fails to decode, or whose VNI is unknown, are dropped. s.mu is held
// while the endpoint processes the packet.
func (s *Server) handlePacket(from netip.AddrPort, b []byte, uw udpWriter) {
	var gh packet.GeneveHeader
	if err := gh.Decode(b); err != nil {
		return
	}
	// TODO: consider performance implications of holding s.mu for the remainder
	// of this method, which does a bunch of disco/crypto work depending. Keep
	// it simple (and slow) for now.
	s.mu.Lock()
	defer s.mu.Unlock()
	if ep, known := s.byVNI[gh.VNI]; known {
		ep.handlePacket(from, gh, b, uw, s.discoPublic)
	}
}
// packetReadLoop reads packets from s.uc and passes each one to handlePacket.
// It stops at the first read error, e.g. once Close has closed the socket,
// and then closes the Server.
func (s *Server) packetReadLoop() {
	// Deferred calls run LIFO: wg.Done fires before Close, so a Close that is
	// waiting on s.wg is never blocked on this goroutine.
	defer s.Close()
	defer s.wg.Done()

	buf := make([]byte, 1<<16-1)
	for {
		// TODO: extract laddr from IP_PKTINFO for use in reply
		n, src, err := s.uc.ReadFromUDPAddrPort(buf)
		if err != nil {
			return
		}
		s.handlePacket(src, buf[:n], s.uc)
	}
}
// ErrServerClosed is returned by Server.AllocateEndpoint once the Server has
// been closed.
var ErrServerClosed = errors.New("server closed")
// AllocateEndpoint allocates a ServerEndpoint for the provided pair of
// key.DiscoPublic's. It returns an error (ErrServerClosed) if the server has
// been closed. It also returns an error in three cases: either key equals the
// Server's own disco key; discoA and discoB are the same key; or the VNI pool
// is exhausted.
//
// If a not-yet-bound endpoint already exists for the pair, it is returned
// unchanged (same VNI and LamportID). If a bound one exists, it is discarded
// and a fresh endpoint is allocated.
func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (ServerEndpoint, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.closed {
		return ServerEndpoint{}, ErrServerClosed
	}
	if discoA.Compare(s.discoPublic) == 0 || discoB.Compare(s.discoPublic) == 0 {
		return ServerEndpoint{}, fmt.Errorf("client disco equals server disco: %s", s.discoPublic.ShortString())
	}
	if discoA.Compare(discoB) == 0 {
		// An endpoint relays between two distinct clients. With identical
		// keys, every handshake message resolves to the first key slot, so
		// the second slot could never complete its handshake and the endpoint
		// would never bind.
		return ServerEndpoint{}, fmt.Errorf("client disco keys are equal: %s", discoA.ShortString())
	}

	// toServerEndpoint renders an internal endpoint as the caller-facing value.
	toServerEndpoint := func(ep *serverEndpoint) ServerEndpoint {
		return ServerEndpoint{
			ServerDisco: s.discoPublic,
			// Clone so callers cannot mutate s.addrPorts, which is shared by
			// every allocation.
			AddrPorts:           slices.Clone(s.addrPorts),
			VNI:                 ep.vni,
			LamportID:           ep.lamportID,
			BindLifetime:        s.bindLifetime,
			SteadyStateLifetime: s.steadyStateLifetime,
		}
	}

	pair := newPairOfDiscoPubKeys(discoA, discoB)
	if existing, ok := s.byDisco[pair]; ok {
		if !existing.isBound() {
			// If the endpoint is not yet bound this is likely an allocation
			// race between two clients on the same Server. Instead of
			// re-allocating we return the existing allocation. We do not reset
			// allocatedAt in case a client is "stuck" in an allocation
			// loop and will not be able to complete a handshake, for whatever
			// reason. Once the endpoint expires a new endpoint will be
			// allocated. Clients can resolve duplicate ServerEndpoint details
			// via ServerEndpoint.LamportID.
			//
			// TODO: consider ServerEndpoint.BindLifetime -= time.Now()-allocatedAt
			// to give the client a more accurate picture of the bind window.
			// Or, some threshold to trigger re-allocation if too much time has
			// already passed since it was originally allocated.
			return toServerEndpoint(existing), nil
		}
		// If an endpoint exists for the pair of key.DiscoPublic's, and is
		// already bound, delete it. We will re-allocate a new endpoint. Chances
		// are clients cannot make use of the existing, bound allocation if
		// they are requesting a new one.
		delete(s.byDisco, pair)
		delete(s.byVNI, existing.vni)
		s.vniPool = append(s.vniPool, existing.vni)
	}

	if len(s.vniPool) == 0 {
		return ServerEndpoint{}, errors.New("VNI pool exhausted")
	}

	s.lamportID++
	e := &serverEndpoint{
		discoPubKeys: pair,
		lamportID:    s.lamportID,
		allocatedAt:  time.Now(),
	}
	e.discoSharedSecrets[0] = s.disco.Shared(e.discoPubKeys[0])
	e.discoSharedSecrets[1] = s.disco.Shared(e.discoPubKeys[1])
	e.vni, s.vniPool = s.vniPool[0], s.vniPool[1:]
	rand.Read(e.challenge[0][:])
	rand.Read(e.challenge[1][:])
	s.byDisco[pair] = e
	s.byVNI[e.vni] = e
	return toServerEndpoint(e), nil
}