net/udprelay: start of UDP relay server implementation

Updates tailscale/corp#27101

Signed-off-by: Jordan Whited <jordan@tailscale.com>
This commit is contained in:
Jordan Whited 2025-03-13 16:11:10 -07:00
parent 3a4b622276
commit fb8561ee9d
No known key found for this signature in database
GPG Key ID: 33DF352F65991EB8
2 changed files with 544 additions and 3 deletions

View File

@ -41,9 +41,12 @@ const NonceLen = 24
type MessageType byte
const (
TypePing = MessageType(0x01)
TypePong = MessageType(0x02)
TypeCallMeMaybe = MessageType(0x03)
TypePing = MessageType(0x01)
TypePong = MessageType(0x02)
TypeCallMeMaybe = MessageType(0x03)
TypeBindUDPEndpoint = MessageType(0x04)
TypeBindUDPEndpointChallenge = MessageType(0x05)
TypeBindUDPEndpointAnswer = MessageType(0x06)
)
const v0 = byte(0)
@ -83,6 +86,12 @@ func Parse(p []byte) (Message, error) {
return parsePong(ver, p)
case TypeCallMeMaybe:
return parseCallMeMaybe(ver, p)
case TypeBindUDPEndpoint:
return parseBindUDPEndpoint(ver, p)
case TypeBindUDPEndpointChallenge:
return parseBindUDPEndpointChallenge(ver, p)
case TypeBindUDPEndpointAnswer:
return parseBindUDPEndpointAnswer(ver, p)
default:
return nil, fmt.Errorf("unknown message type 0x%02x", byte(t))
}
@ -266,7 +275,117 @@ func MessageSummary(m Message) string {
return fmt.Sprintf("pong tx=%x", m.TxID[:6])
case *CallMeMaybe:
return "call-me-maybe"
case *BindUDPEndpoint:
return "bind-udp-endpoint"
case *BindUDPEndpointChallenge:
return "bind-udp-endpoint-challenge"
case *BindUDPEndpointAnswer:
return "bind-udp-endpoint-answer"
default:
return fmt.Sprintf("%#v", m)
}
}
// BindHandshakeState represents the state of the 3-way bind handshake between
// UDP relay client and UDP relay server. Its potential values (constants below)
// include those for both sides of the handshake, UDP relay client and UDP relay
// server. A UDP relay server implementation exists in net/udprelay. This is
// currently considered experimental.
type BindHandshakeState int

const (
	// BindHandshakeStateInit represents the initial state prior to any message
	// being transmitted. It is the first constant in this iota block, so it is
	// also the zero value of BindHandshakeState.
	BindHandshakeStateInit BindHandshakeState = iota
	// BindHandshakeStateBindSent is a potential UDP relay client state once it
	// has transmitted a BindUDPEndpoint message towards a UDP relay server.
	BindHandshakeStateBindSent
	// BindHandshakeStateChallengeSent is a potential UDP relay server state
	// once it has transmitted a BindUDPEndpointChallenge message towards a UDP
	// relay client in response to a BindUDPEndpoint message.
	BindHandshakeStateChallengeSent
	// BindHandshakeStateAnswerSent is a potential UDP relay client state once
	// it has transmitted a BindUDPEndpointAnswer message towards a UDP relay
	// server in response to a BindUDPEndpointChallenge message.
	BindHandshakeStateAnswerSent
	// BindHandshakeStateAnswerReceived is a potential UDP relay server state
	// once it has received a valid/correct BindUDPEndpointAnswer message from a
	// UDP relay client in response to a BindUDPEndpointChallenge message.
	BindHandshakeStateAnswerReceived
)
// bindUDPEndpointLen is the length of a marshalled BindUDPEndpoint message,
// without the message header. It is defined as equal to
// BindUDPEndpointChallengeLen.
const bindUDPEndpointLen = BindUDPEndpointChallengeLen

// BindUDPEndpoint is the first message transmitted from UDP relay client
// towards UDP relay server as part of the 3-way bind handshake. It is padded to
// match the length of BindUDPEndpointChallenge. This message type is currently
// considered experimental and is not yet tied to a tailcfg.CapabilityVersion.
type BindUDPEndpoint struct {
	// padding reserves bindUDPEndpointLen bytes; it carries no information.
	padding [bindUDPEndpointLen]byte
}
// AppendMarshal appends the marshalled form of m to b and returns the
// extended slice.
//
// The payload that follows the message header is bindUDPEndpointLen bytes of
// padding, so that a BindUDPEndpoint occupies the same number of bytes as the
// BindUDPEndpointChallenge sent in response to it. Previously a payload length
// of 0 was requested, which contradicted the documented padding and made the
// server's reply larger than the request that triggered it.
func (m *BindUDPEndpoint) AppendMarshal(b []byte) []byte {
	ret, d := appendMsgHeader(b, TypeBindUDPEndpoint, v0, bindUDPEndpointLen)
	// Copy the padding explicitly rather than relying on how appendMsgHeader
	// initialises the payload bytes.
	copy(d, m.padding[:])
	return ret
}
// parseBindUDPEndpoint returns a freshly allocated BindUDPEndpoint. Neither ver
// nor the payload p is inspected: the message carries no fields besides
// padding, so a payload of any length is accepted and the error is always nil.
func parseBindUDPEndpoint(ver uint8, p []byte) (m *BindUDPEndpoint, err error) {
	return &BindUDPEndpoint{}, nil
}
// BindUDPEndpointChallengeLen is the length of a marshalled
// BindUDPEndpointChallenge message, without the message header.
const BindUDPEndpointChallengeLen = 32

// BindUDPEndpointChallenge is transmitted from UDP relay server towards UDP
// relay client in response to a BindUDPEndpoint message as part of the 3-way
// bind handshake. This message type is currently considered experimental and is
// not yet tied to a tailcfg.CapabilityVersion.
type BindUDPEndpointChallenge struct {
	// Challenge is the challenge value issued by the server; it makes up the
	// entire message payload.
	Challenge [BindUDPEndpointChallengeLen]byte
}
// AppendMarshal appends the marshalled form of m (message header followed by
// the BindUDPEndpointChallengeLen challenge bytes) to b and returns the
// extended slice.
func (m *BindUDPEndpointChallenge) AppendMarshal(b []byte) []byte {
	out, payload := appendMsgHeader(b, TypeBindUDPEndpointChallenge, v0, BindUDPEndpointChallengeLen)
	copy(payload, m.Challenge[:])
	return out
}
// parseBindUDPEndpointChallenge decodes a BindUDPEndpointChallenge from the
// payload p. It returns errShort when p holds fewer than
// BindUDPEndpointChallengeLen bytes; bytes beyond that length are ignored. ver
// is not inspected.
func parseBindUDPEndpointChallenge(ver uint8, p []byte) (m *BindUDPEndpointChallenge, err error) {
	if len(p) < BindUDPEndpointChallengeLen {
		return nil, errShort
	}
	m = &BindUDPEndpointChallenge{}
	// copy stops at the shorter operand, so only the leading
	// BindUDPEndpointChallengeLen bytes of p are taken.
	copy(m.Challenge[:], p)
	return m, nil
}
// bindUDPEndpointAnswerLen is the length of a marshalled
// BindUDPEndpointAnswer message, without the message header. It is defined as
// equal to BindUDPEndpointChallengeLen.
const bindUDPEndpointAnswerLen = BindUDPEndpointChallengeLen

// BindUDPEndpointAnswer is transmitted from UDP relay client to UDP relay
// server in response to a BindUDPEndpointChallenge message. This message type
// is currently considered experimental and is not yet tied to a
// tailcfg.CapabilityVersion.
type BindUDPEndpointAnswer struct {
	// Answer is the client's response to the server's challenge; it makes up
	// the entire message payload.
	Answer [bindUDPEndpointAnswerLen]byte
}
// AppendMarshal appends the marshalled form of m (message header followed by
// the bindUDPEndpointAnswerLen answer bytes) to b and returns the extended
// slice.
func (m *BindUDPEndpointAnswer) AppendMarshal(b []byte) []byte {
	out, payload := appendMsgHeader(b, TypeBindUDPEndpointAnswer, v0, bindUDPEndpointAnswerLen)
	copy(payload, m.Answer[:])
	return out
}
// parseBindUDPEndpointAnswer decodes a BindUDPEndpointAnswer from the payload
// p. It returns errShort when p holds fewer than bindUDPEndpointAnswerLen
// bytes; bytes beyond that length are ignored. ver is not inspected.
func parseBindUDPEndpointAnswer(ver uint8, p []byte) (m *BindUDPEndpointAnswer, err error) {
	if len(p) < bindUDPEndpointAnswerLen {
		return nil, errShort
	}
	m = &BindUDPEndpointAnswer{}
	// copy stops at the shorter operand, so only the leading
	// bindUDPEndpointAnswerLen bytes of p are taken.
	copy(m.Answer[:], p)
	return m, nil
}

422
net/udprelay/server.go Normal file
View File

@ -0,0 +1,422 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package udprelay
import (
"bytes"
"crypto/rand"
"errors"
"fmt"
"net"
"net/netip"
"strconv"
"sync"
"time"
"go4.org/mem"
"tailscale.com/disco"
"tailscale.com/net/packet"
"tailscale.com/types/key"
)
const (
	// defaultBindLifetime is the bind lifetime NewServer configures: how long
	// an allocated endpoint may remain unbound before it is considered
	// expired.
	defaultBindLifetime = time.Second * 5
	// defaultSteadyStateLifetime is the steady state lifetime NewServer
	// configures: how stale a bound endpoint's last-seen times may become
	// before it is considered expired.
	defaultSteadyStateLifetime = time.Minute * 5
)
// Server implements a UDP relay server.
type Server struct {
	// disco keypair used as part of 3-way bind handshake
	disco key.DiscoPrivate
	discoPublic key.DiscoPublic

	// bindLifetime and steadyStateLifetime are the expiry thresholds handed to
	// serverEndpoint.expired by the endpoint GC loop. bindLifetime is also the
	// GC ticker interval.
	bindLifetime time.Duration
	steadyStateLifetime time.Duration

	// addrPorts contains the ip:port pairs returned as candidate server
	// endpoints in response to an allocation request.
	addrPorts []netip.AddrPort

	// uc is the UDP socket packets are read from and written to.
	uc *net.UDPConn

	// closeOnce ensures the body of Close runs a single time.
	closeOnce sync.Once
	// wg tracks the packetReadLoop and endpointGC goroutines.
	wg sync.WaitGroup
	// closeCh is closed by Close to tell endpointGC to exit.
	closeCh chan struct{}
	// closed is set by Close and checked by AllocateEndpoint. Although it is
	// declared above mu, both accesses happen while holding mu.
	closed bool

	mu sync.Mutex // guards the following fields
	vniPool []uint32 // the pool of available VNIs
	// byVNI and byDisco index the same set of endpoints, by VNI and by the
	// ordered pair of client disco public keys respectively.
	byVNI map[uint32]*serverEndpoint
	byDisco map[pairOfDiscoPubKeys]*serverEndpoint
}
// pairOfDiscoPubKeys holds two key.DiscoPublic values in lexicographical
// order. Always build one with newPairOfDiscoPubKeys so that the ordering
// holds.
type pairOfDiscoPubKeys [2]key.DiscoPublic

// String renders the pair as the ShortString of each key joined by " <=> ".
func (p pairOfDiscoPubKeys) String() string {
	first, second := p[0].ShortString(), p[1].ShortString()
	return fmt.Sprintf("%s <=> %s", first, second)
}
// newPairOfDiscoPubKeys returns discoA and discoB as a pairOfDiscoPubKeys,
// swapping them when discoA compares greater than discoB so that the same two
// keys always produce the same pair regardless of argument order.
func newPairOfDiscoPubKeys(discoA, discoB key.DiscoPublic) pairOfDiscoPubKeys {
	if discoA.Compare(discoB) == 1 {
		return pairOfDiscoPubKeys{discoB, discoA}
	}
	return pairOfDiscoPubKeys{discoA, discoB}
}
// ServerEndpoint contains the Server's endpoint details. It is the value
// returned by Server.AllocateEndpoint.
type ServerEndpoint struct {
	// ServerDisco is the Server's Disco public key used as part of the 3-way
	// bind handshake.
	ServerDisco key.DiscoPublic
	// AddrPorts are the IP:Port candidate pairs the Server may be reachable
	// over.
	AddrPorts []netip.AddrPort
	// VNI (Virtual Network Identifier) is the Geneve header VNI the Server
	// expects for associated packets.
	VNI uint32
	// BindLifetime is the amount of time post-allocation the Server will keep
	// the endpoint alive while it has yet to be bound.
	BindLifetime time.Duration
	// SteadyStateLifetime is the amount of time post-bind the Server will keep
	// the endpoint alive lacking bidirectional data flow.
	SteadyStateLifetime time.Duration
}
// serverEndpoint is the Server's internal state for one allocated relay
// endpoint between two clients. Every two-element array below is indexed by
// client, in the same order as discoPubKeys.
type serverEndpoint struct {
	// discoPubKeys is the ordered pair of client disco public keys the
	// endpoint was allocated for.
	discoPubKeys pairOfDiscoPubKeys
	// discoSharedSecrets holds the shared secret between the Server's disco
	// key and each client's disco key.
	discoSharedSecrets [2]key.DiscoShared
	// handeshakeState is each client's progress through the 3-way bind
	// handshake, as seen by the server.
	handeshakeState [2]disco.BindHandshakeState
	// addrPorts is the source address each client's first BindUDPEndpoint
	// message arrived from.
	addrPorts [2]netip.AddrPort
	// lastSeen is compared against the steady state lifetime by expired once
	// the endpoint is bound.
	lastSeen [2]time.Time
	// challenge is the per-client challenge the server generated for the
	// handshake.
	challenge [2][disco.BindUDPEndpointChallengeLen]byte
	// vni is the Geneve VNI assigned to this endpoint from the Server's pool.
	vni uint32
	// allocatedAt is when the endpoint was allocated, or last re-requested
	// while still unbound.
	allocatedAt time.Time
}
// expired reports whether e should be garbage collected as of now.
//
// An endpoint that is not yet bound expires once more than bindLifetime has
// elapsed since allocatedAt. A bound endpoint expires once more than
// steadyStateLifetime has elapsed since either client's lastSeen time.
func (e *serverEndpoint) expired(now time.Time, bindLifetime, steadyStateLifetime time.Duration) bool {
	if !e.bound() {
		return now.Sub(e.allocatedAt) > bindLifetime
	}
	for _, seen := range e.lastSeen {
		if now.Sub(seen) > steadyStateLifetime {
			return true
		}
	}
	return false
}
// bound reports whether both clients have completed their 3-way handshake,
// i.e. whether both handshake states are BindHandshakeStateAnswerReceived.
func (e *serverEndpoint) bound() bool {
	for _, state := range e.handeshakeState {
		if state != disco.BindHandshakeStateAnswerReceived {
			return false
		}
	}
	return true
}
// NewServer constructs a Server listening on port and returning addrs:port
// in response to allocation requests.
//
// It generates a fresh disco key, fills the VNI pool with every value in
// [1, 1<<24), binds the UDP socket and starts the packet read loop and the
// endpoint GC goroutines. An error is returned if any addr cannot be combined
// with port into a netip.AddrPort, or if the socket cannot be bound.
func NewServer(port uint16, addrs []netip.Addr) (*Server, error) {
	s := &Server{
		disco:               key.NewDisco(),
		bindLifetime:        defaultBindLifetime,
		steadyStateLifetime: defaultSteadyStateLifetime,
		closeCh:             make(chan struct{}),
		// Both indexes must be non-nil: AllocateEndpoint writes to them, and a
		// write to a nil map panics.
		byVNI:   make(map[uint32]*serverEndpoint),
		byDisco: make(map[pairOfDiscoPubKeys]*serverEndpoint),
	}
	s.discoPublic = s.disco.Public()

	addrPorts := make([]netip.AddrPort, 0, len(addrs))
	for _, addr := range addrs {
		addrPort, err := netip.ParseAddrPort(net.JoinHostPort(addr.String(), strconv.Itoa(int(port))))
		if err != nil {
			return nil, err
		}
		addrPorts = append(addrPorts, addrPort)
	}
	s.addrPorts = addrPorts

	// TODO: instead of allocating 10s of MBs for the full pool, allocate
	// smaller chunks and increase only if needed
	s.vniPool = make([]uint32, 0, 1<<24-1)
	for i := 1; i < 1<<24; i++ {
		s.vniPool = append(s.vniPool, uint32(i))
	}

	// TODO: this assumes multi-af socket capability, but we should probably
	// bind explicit ipv4 and ipv6 sockets.
	uc, err := net.ListenUDP("udp", &net.UDPAddr{Port: int(port)})
	if err != nil {
		return nil, err
	}
	s.uc = uc

	// One count per goroutine; each calls s.wg.Done when it exits.
	s.wg.Add(2)
	go s.packetReadLoop()
	go s.endpointGC()
	return s, nil
}
// Close closes the server. It closes the UDP socket, signals the background
// goroutines to stop, waits for them to exit and then releases all endpoint
// state. It is safe to call more than once; only the first call does any work.
// It always returns nil.
func (s *Server) Close() error {
	s.closeOnce.Do(func() {
		// Mark the server closed first so AllocateEndpoint stops handing out
		// endpoints while shutdown is in progress.
		s.mu.Lock()
		s.closed = true
		s.mu.Unlock()

		// Closing the socket makes the read in packetReadLoop fail; closing
		// closeCh makes endpointGC return. The close error is deliberately
		// ignored: there is nothing further to do with the socket.
		s.uc.Close()
		close(s.closeCh)

		// s.mu must NOT be held while waiting: handlePacket and the GC sweep
		// both take s.mu, so holding it here would keep those goroutines from
		// ever reaching wg.Done and Close would deadlock.
		s.wg.Wait()

		s.mu.Lock()
		defer s.mu.Unlock()
		clear(s.byVNI)
		clear(s.byDisco)
		s.vniPool = nil
	})
	return nil
}
// endpointGC runs until closeCh is closed. Every bindLifetime it sweeps the
// endpoint table, removing each expired endpoint from both indexes and
// returning its VNI to the pool. It calls s.wg.Done on exit.
func (s *Server) endpointGC() {
	defer s.wg.Done()

	tick := time.NewTicker(s.bindLifetime)
	defer tick.Stop()

	// sweep performs one garbage collection pass while holding s.mu.
	sweep := func() {
		asOf := time.Now()
		s.mu.Lock()
		defer s.mu.Unlock()
		for pair, ep := range s.byDisco {
			if !ep.expired(asOf, s.bindLifetime, s.steadyStateLifetime) {
				continue
			}
			delete(s.byDisco, pair)
			delete(s.byVNI, ep.vni)
			s.vniPool = append(s.vniPool, ep.vni)
		}
	}

	for {
		select {
		case <-s.closeCh:
			return
		case <-tick.C:
			sweep()
		}
	}
}
// handlePacket processes one datagram b received from from. It holds s.mu for
// the duration of the call. Packets that cannot be handled are dropped
// silently.
//
// b must begin with a Geneve header whose VNI identifies an allocated
// endpoint. What happens next depends on the header's Control bit:
//
//   - Data packets (Control unset) are relayed verbatim to the other client,
//     but only once the endpoint is bound and only if from matches one of the
//     two client addresses recorded during the handshake. The sender's
//     lastSeen time is refreshed so that an endpoint carrying traffic in both
//     directions is not expired by the endpoint GC.
//   - Control packets are accepted only while the endpoint is not yet bound.
//     They must carry a disco message sealed by one of the endpoint's two
//     clients, and drive that client's side of the 3-way bind handshake:
//     a BindUDPEndpoint is answered with a BindUDPEndpointChallenge, and a
//     BindUDPEndpointAnswer matching the challenge marks the client bound.
func (s *Server) handlePacket(from netip.AddrPort, b []byte) {
	gh := packet.GeneveHeader{}
	if err := gh.Decode(b); err != nil {
		return
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	e, ok := s.byVNI[gh.VNI]
	if !ok {
		// unknown VNI
		return
	}

	if !gh.Control {
		if !e.bound() {
			// not a control packet, but serverEndpoint isn't bound
			return
		}
		var to netip.AddrPort
		switch {
		case from == e.addrPorts[0]:
			e.lastSeen[0] = time.Now()
			to = e.addrPorts[1]
		case from == e.addrPorts[1]:
			e.lastSeen[1] = time.Now()
			to = e.addrPorts[0]
		default:
			// unrecognized source
			return
		}
		// relay packet; best-effort, a failed write is equivalent to loss
		s.uc.WriteMsgUDPAddrPort(b, nil, to)
		return
	}

	if e.bound() {
		// control packet, but serverEndpoint is already bound
		return
	}

	if gh.Protocol != packet.GeneveProtocolDisco {
		// control packet, but not Disco
		return
	}

	msg := b[packet.GeneveFixedHeaderLength:]
	senderRaw, isDiscoMsg := disco.Source(msg)
	if !isDiscoMsg {
		// Geneve header protocol field indicated it was Disco, but it's not
		return
	}
	sender := key.DiscoPublicFromRaw32(mem.B(senderRaw))
	senderIndex := -1
	switch {
	case sender.Compare(e.discoPubKeys[0]) == 0:
		senderIndex = 0
	case sender.Compare(e.discoPubKeys[1]) == 0:
		senderIndex = 1
	default:
		// unknown Disco public key
		return
	}

	const headerLen = len(disco.Magic) + key.DiscoPublicRawLen
	discoPayload, ok := e.discoSharedSecrets[senderIndex].Open(msg[headerLen:])
	if !ok {
		// unable to decrypt the disco payload
		return
	}

	discoMsg, err := disco.Parse(discoPayload)
	if err != nil {
		// unable to parse the disco payload
		return
	}

	handshakeState := e.handeshakeState[senderIndex]
	if handshakeState == disco.BindHandshakeStateAnswerReceived {
		// this sender is already bound
		return
	}

	switch discoMsg := discoMsg.(type) {
	case *disco.BindUDPEndpoint:
		switch handshakeState {
		case disco.BindHandshakeStateInit:
			// generate a challenge, maybe we should do this at allocation time?
			if _, err := rand.Read(e.challenge[senderIndex][:]); err != nil {
				// Without a random challenge the handshake proves nothing;
				// stay in the init state so a retransmitted bind retries.
				return
			}
			// set sender addr
			e.addrPorts[senderIndex] = from
			fallthrough
		case disco.BindHandshakeStateChallengeSent:
			if from != e.addrPorts[senderIndex] {
				// this is a later arriving bind from a different source, discard
				return
			}
			m := new(disco.BindUDPEndpointChallenge)
			copy(m.Challenge[:], e.challenge[senderIndex][:])
			// The reply reuses the received Geneve header, followed by the
			// disco magic, the server's disco public key and the sealed
			// challenge message.
			reply := make([]byte, packet.GeneveFixedHeaderLength, 512)
			if err := gh.Encode(reply); err != nil {
				return
			}
			reply = append(reply, disco.Magic...)
			reply = s.discoPublic.AppendTo(reply)
			box := e.discoSharedSecrets[senderIndex].Seal(m.AppendMarshal(nil))
			reply = append(reply, box...)
			// best-effort; the client retransmits its bind if this is lost
			s.uc.WriteMsgUDPAddrPort(reply, nil, from)
			// set new state
			e.handeshakeState[senderIndex] = disco.BindHandshakeStateChallengeSent
			return
		default:
			// disco.BindUDPEndpoint is unexpected in all other handshake states
			return
		}
	case *disco.BindUDPEndpointAnswer:
		switch handshakeState {
		case disco.BindHandshakeStateChallengeSent:
			if from != e.addrPorts[senderIndex] {
				// sender source has changed
				return
			}
			if !bytes.Equal(discoMsg.Answer[:], e.challenge[senderIndex][:]) {
				// bad answer
				return
			}
			// sender is now bound
			e.handeshakeState[senderIndex] = disco.BindHandshakeStateAnswerReceived
			// record last seen as bound time
			e.lastSeen[senderIndex] = time.Now()
			return
		default:
			// disco.BindUDPEndpointAnswer is unexpected in all other handshake
			// states, or we've already handled it
			return
		}
	default:
		// unexpected Disco message type
		return
	}
}
// packetReadLoop reads datagrams from the UDP socket and passes each one to
// handlePacket until a read fails. On exit it calls s.wg.Done and then
// s.Close, in that order.
func (s *Server) packetReadLoop() {
	// Deferred calls run last-in first-out: wg.Done runs before Close.
	defer s.Close()
	defer s.wg.Done()

	buf := make([]byte, 1<<16-1)
	for {
		n, src, err := s.uc.ReadFromUDPAddrPort(buf)
		if err != nil {
			return
		}
		s.handlePacket(src, buf[:n])
	}
}
// ErrServerClosed is returned by Server.AllocateEndpoint when the Server has
// been closed.
var ErrServerClosed = errors.New("server closed")
// AllocateEndpoint allocates a ServerEndpoint for the provided pair of
// key.DiscoPublic's. It returns ErrServerClosed if the server has been closed.
//
// The order of discoA and discoB does not matter. If an endpoint already
// exists for the pair and is not yet bound, that endpoint is returned again
// with its allocation time reset. If it exists and is bound, it is discarded
// and a new endpoint with a new VNI is allocated. An error is returned when no
// VNI is available.
//
// The returned ServerEndpoint always carries the Server's disco public key and
// the lifetimes the Server is actually configured with.
func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (ServerEndpoint, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	if s.closed {
		return ServerEndpoint{}, ErrServerClosed
	}

	// describe builds the caller-facing view of an endpoint. Both return
	// paths share it so they cannot drift apart.
	describe := func(vni uint32) ServerEndpoint {
		return ServerEndpoint{
			ServerDisco:         s.discoPublic,
			AddrPorts:           s.addrPorts,
			VNI:                 vni,
			BindLifetime:        s.bindLifetime,
			SteadyStateLifetime: s.steadyStateLifetime,
		}
	}

	pair := newPairOfDiscoPubKeys(discoA, discoB)
	e, ok := s.byDisco[pair]
	if ok {
		if !e.bound() {
			// If the endpoint is not yet bound this is likely an allocation
			// race between two clients utilizing the same relay. Instead of
			// re-allocating we return the existing allocation state, and reset
			// the allocation time.
			e.allocatedAt = time.Now()
			return describe(e.vni), nil
		}
		// If an endpoint exists for the pair of key.DiscoPublic's, and is
		// already bound, delete it. We will re-allocate a new endpoint. Chances
		// are clients cannot make use of the existing, bound allocation if
		// they are requesting a new one.
		delete(s.byDisco, pair)
		delete(s.byVNI, e.vni)
		s.vniPool = append(s.vniPool, e.vni)
	}

	if len(s.vniPool) == 0 {
		return ServerEndpoint{}, errors.New("VNI pool exhausted")
	}

	e = &serverEndpoint{
		discoPubKeys: pair,
		allocatedAt:  time.Now(),
	}
	e.discoSharedSecrets[0] = s.disco.Shared(e.discoPubKeys[0])
	e.discoSharedSecrets[1] = s.disco.Shared(e.discoPubKeys[1])
	// Take the VNI at the front of the pool.
	e.vni, s.vniPool = s.vniPool[0], s.vniPool[1:]

	s.byDisco[pair] = e
	s.byVNI[e.vni] = e

	return describe(e.vni), nil
}