From fb8561ee9dc67b2c1c32b858ce5e38658828a65a Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Thu, 13 Mar 2025 16:11:10 -0700 Subject: [PATCH] net/udprelay: start of UDP relay server implementation Updates tailscale/corp#27101 Signed-off-by: Jordan Whited --- disco/disco.go | 125 +++++++++++- net/udprelay/server.go | 422 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 544 insertions(+), 3 deletions(-) create mode 100644 net/udprelay/server.go diff --git a/disco/disco.go b/disco/disco.go index b9a90029d..dae75d2fd 100644 --- a/disco/disco.go +++ b/disco/disco.go @@ -41,9 +41,12 @@ const NonceLen = 24 type MessageType byte const ( - TypePing = MessageType(0x01) - TypePong = MessageType(0x02) - TypeCallMeMaybe = MessageType(0x03) + TypePing = MessageType(0x01) + TypePong = MessageType(0x02) + TypeCallMeMaybe = MessageType(0x03) + TypeBindUDPEndpoint = MessageType(0x04) + TypeBindUDPEndpointChallenge = MessageType(0x05) + TypeBindUDPEndpointAnswer = MessageType(0x06) ) const v0 = byte(0) @@ -83,6 +86,12 @@ func Parse(p []byte) (Message, error) { return parsePong(ver, p) case TypeCallMeMaybe: return parseCallMeMaybe(ver, p) + case TypeBindUDPEndpoint: + return parseBindUDPEndpoint(ver, p) + case TypeBindUDPEndpointChallenge: + return parseBindUDPEndpointChallenge(ver, p) + case TypeBindUDPEndpointAnswer: + return parseBindUDPEndpointAnswer(ver, p) default: return nil, fmt.Errorf("unknown message type 0x%02x", byte(t)) } @@ -266,7 +275,117 @@ func MessageSummary(m Message) string { return fmt.Sprintf("pong tx=%x", m.TxID[:6]) case *CallMeMaybe: return "call-me-maybe" + case *BindUDPEndpoint: + return "bind-udp-endpoint" + case *BindUDPEndpointChallenge: + return "bind-udp-endpoint-challenge" + case *BindUDPEndpointAnswer: + return "bind-udp-endpoint-answer" default: return fmt.Sprintf("%#v", m) } } + +// BindHandshakeState represents the state of the 3-way bind handshake between +// UDP relay client and UDP relay server. 
+// Its potential values (constants below)
+// include those for both sides of the handshake, UDP relay client and UDP relay
+// server. A UDP relay server implementation exists in net/udprelay. This is
+// currently considered experimental.
+type BindHandshakeState int
+
+// Values are assigned with iota starting at zero, so BindHandshakeStateInit is
+// the zero value of BindHandshakeState.
+const (
+	// BindHandshakeStateInit represents the initial state prior to any message
+	// being transmitted.
+	BindHandshakeStateInit BindHandshakeState = iota
+	// BindHandshakeStateBindSent is a potential UDP relay client state once it
+	// has transmitted a BindUDPEndpoint message towards a UDP relay server.
+	BindHandshakeStateBindSent
+	// BindHandshakeStateChallengeSent is a potential UDP relay server state
+	// once it has transmitted a BindUDPEndpointChallenge message towards a UDP
+	// relay client in response to a BindUDPEndpoint message.
+	BindHandshakeStateChallengeSent
+	// BindHandshakeStateAnswerSent is a potential UDP relay client state once
+	// it has transmitted a BindUDPEndpointAnswer message towards a UDP relay
+	// server in response to a BindUDPEndpointChallenge message.
+	BindHandshakeStateAnswerSent
+	// BindHandshakeStateAnswerReceived is a potential UDP relay server state
+	// once it has received a valid/correct BindUDPEndpointAnswer message from a
+	// UDP relay client in response to a BindUDPEndpointChallenge message.
+	BindHandshakeStateAnswerReceived
+)
+
+// bindUDPEndpointLen is the length of a marshalled BindUDPEndpoint message,
+// without the message header. It is defined to equal
+// BindUDPEndpointChallengeLen.
+const bindUDPEndpointLen = BindUDPEndpointChallengeLen
+
+// BindUDPEndpoint is the first message transmitted from UDP relay client
+// towards UDP relay server as part of the 3-way bind handshake. It is padded to
+// match the length of BindUDPEndpointChallenge. This message type is currently
+// considered experimental and is not yet tied to a tailcfg.CapabilityVersion.
+type BindUDPEndpoint struct {
+	// padding is always zero. It exists only so the marshalled message is as
+	// long as the BindUDPEndpointChallenge it solicits.
+	padding [bindUDPEndpointLen]byte
+}
+
+// AppendMarshal appends the marshalled form of m to b and returns the
+// extended slice. The payload is bindUDPEndpointLen bytes of padding so that
+// the bind request is never smaller than the challenge sent in response.
+func (m *BindUDPEndpoint) AppendMarshal(b []byte) []byte {
+	ret, d := appendMsgHeader(b, TypeBindUDPEndpoint, v0, bindUDPEndpointLen)
+	copy(d, m.padding[:])
+	return ret
+}
+
+// parseBindUDPEndpoint parses the payload p (message header already removed)
+// of a BindUDPEndpoint message. It returns errShort if p is shorter than
+// bindUDPEndpointLen, i.e. if the sender omitted the padding. The padding
+// content itself is not inspected.
+func parseBindUDPEndpoint(ver uint8, p []byte) (m *BindUDPEndpoint, err error) {
+	if len(p) < bindUDPEndpointLen {
+		return nil, errShort
+	}
+	return new(BindUDPEndpoint), nil
+}
+
+// BindUDPEndpointChallengeLen is the length of a marshalled
+// BindUDPEndpointChallenge message, without the message header.
+const BindUDPEndpointChallengeLen = 32
+
+// BindUDPEndpointChallenge is transmitted from UDP relay server towards UDP
+// relay client in response to a BindUDPEndpoint message as part of the 3-way
+// bind handshake. This message type is currently considered experimental and is
+// not yet tied to a tailcfg.CapabilityVersion.
+type BindUDPEndpointChallenge struct {
+	Challenge [BindUDPEndpointChallengeLen]byte
+}
+
+// AppendMarshal appends the marshalled form of m (message header followed by
+// the challenge bytes) to b and returns the extended slice.
+func (m *BindUDPEndpointChallenge) AppendMarshal(b []byte) []byte {
+	ret, d := appendMsgHeader(b, TypeBindUDPEndpointChallenge, v0, BindUDPEndpointChallengeLen)
+	copy(d, m.Challenge[:])
+	return ret
+}
+
+// parseBindUDPEndpointChallenge parses the payload p (message header already
+// removed) of a BindUDPEndpointChallenge message. It returns errShort if p
+// holds fewer than BindUDPEndpointChallengeLen bytes; trailing bytes are
+// ignored.
+func parseBindUDPEndpointChallenge(ver uint8, p []byte) (m *BindUDPEndpointChallenge, err error) {
+	if len(p) < BindUDPEndpointChallengeLen {
+		return nil, errShort
+	}
+	m = new(BindUDPEndpointChallenge)
+	copy(m.Challenge[:], p)
+	return m, nil
+}
+
+// bindUDPEndpointAnswerLen is the length of a marshalled
+// BindUDPEndpointAnswer message, without the message header.
+const bindUDPEndpointAnswerLen = BindUDPEndpointChallengeLen
+
+// BindUDPEndpointAnswer is transmitted from UDP relay client to UDP relay
+// server in response to a BindUDPEndpointChallenge message. This message type
+// is currently considered experimental and is not yet tied to a
+// tailcfg.CapabilityVersion.
+type BindUDPEndpointAnswer struct {
+	// Answer carries bindUDPEndpointAnswerLen bytes, the same length as a
+	// challenge.
+	Answer [bindUDPEndpointAnswerLen]byte
+}
+
+// AppendMarshal appends the marshalled form of m (message header followed by
+// the answer bytes) to b and returns the extended slice.
+func (m *BindUDPEndpointAnswer) AppendMarshal(b []byte) []byte {
+	ret, d := appendMsgHeader(b, TypeBindUDPEndpointAnswer, v0, bindUDPEndpointAnswerLen)
+	copy(d, m.Answer[:])
+	return ret
+}
+
+// parseBindUDPEndpointAnswer parses the payload p (message header already
+// removed) of a BindUDPEndpointAnswer message. It returns errShort if p holds
+// fewer than bindUDPEndpointAnswerLen bytes; trailing bytes are ignored.
+func parseBindUDPEndpointAnswer(ver uint8, p []byte) (m *BindUDPEndpointAnswer, err error) {
+	if len(p) < bindUDPEndpointAnswerLen {
+		return nil, errShort
+	}
+	m = new(BindUDPEndpointAnswer)
+	copy(m.Answer[:], p[:])
+	return m, nil
+}
diff --git a/net/udprelay/server.go b/net/udprelay/server.go
new file mode 100644
index 000000000..7c207fd4c
--- /dev/null
+++ b/net/udprelay/server.go
@@ -0,0 +1,422 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package udprelay
+
+import (
+	"bytes"
+	"crypto/rand"
+	"errors"
+	"fmt"
+	"net"
+	"net/netip"
+	"strconv"
+	"sync"
+	"time"
+
+	"go4.org/mem"
+	"tailscale.com/disco"
+	"tailscale.com/net/packet"
+	"tailscale.com/types/key"
+)
+
+const (
+	// defaultBindLifetime is the bind lifetime NewServer configures.
+	defaultBindLifetime = time.Second * 5
+	// defaultSteadyStateLifetime is the steady state lifetime NewServer
+	// configures.
+	defaultSteadyStateLifetime = time.Minute * 5
+)
+
+// Server implements a UDP relay server.
+type Server struct {
+	// disco keypair used as part of 3-way bind handshake
+	disco       key.DiscoPrivate
+	discoPublic key.DiscoPublic
+
+	// bindLifetime and steadyStateLifetime are the durations passed to
+	// serverEndpoint.expired by the endpoint garbage collector.
+	bindLifetime        time.Duration
+	steadyStateLifetime time.Duration
+
+	// addrPorts contains the ip:port pairs returned as candidate server
+	// endpoints in response to an allocation request.
+	addrPorts []netip.AddrPort
+
+	// uc is the UDP socket packets are read from and written to.
+	uc *net.UDPConn
+
+	closeOnce sync.Once      // ensures the body of Close runs once
+	wg        sync.WaitGroup // tracks packetReadLoop and endpointGC
+	closeCh   chan struct{}  // closed by Close to stop endpointGC
+	// closed is set by Close and read by AllocateEndpoint, in both cases while
+	// holding mu, even though it is declared above mu.
+	closed bool
+
+	mu      sync.Mutex // guards the following fields
+	vniPool []uint32   // the pool of available VNIs
+	byVNI   map[uint32]*serverEndpoint
+	byDisco map[pairOfDiscoPubKeys]*serverEndpoint
+}
+
+// pairOfDiscoPubKeys is a pair of key.DiscoPublic. It must be constructed via
+// newPairOfDiscoPubKeys to ensure lexicographical ordering.
+type pairOfDiscoPubKeys [2]key.DiscoPublic
+
+// String renders the pair as the short form of both keys joined by "<=>".
+func (p pairOfDiscoPubKeys) String() string {
+	first, second := p[0].ShortString(), p[1].ShortString()
+	return fmt.Sprintf("%s <=> %s", first, second)
+}
+
+// newPairOfDiscoPubKeys returns discoA and discoB as an ordered pair: the
+// arguments are swapped when discoA compares greater than discoB, and kept in
+// argument order otherwise.
+func newPairOfDiscoPubKeys(discoA, discoB key.DiscoPublic) pairOfDiscoPubKeys {
+	if discoA.Compare(discoB) == 1 {
+		return pairOfDiscoPubKeys{discoB, discoA}
+	}
+	return pairOfDiscoPubKeys{discoA, discoB}
+}
+
+// ServerEndpoint describes an allocated endpoint on the Server as handed back
+// to the caller of an allocation request.
+type ServerEndpoint struct {
+	// ServerDisco is the Server's Disco public key used as part of the 3-way
+	// bind handshake.
+	ServerDisco key.DiscoPublic
+
+	// AddrPorts are the IP:Port candidate pairs the Server may be reachable
+	// over.
+	AddrPorts []netip.AddrPort
+
+	// VNI (Virtual Network Identifier) is the Geneve header VNI the Server
+	// expects for associated packets.
+	VNI uint32
+
+	// BindLifetime is the amount of time post-allocation the Server will keep
+	// the endpoint alive while it has yet to be bound.
+	BindLifetime time.Duration
+
+	// SteadyStateLifetime is the amount of time post-bind the Server will keep
+	// the endpoint alive lacking bidirectional data flow.
+	SteadyStateLifetime time.Duration
+}
+
+// serverEndpoint is the Server's internal per-allocation state. Every
+// two-element array is indexed by the client's position in discoPubKeys.
+type serverEndpoint struct {
+	discoPubKeys       pairOfDiscoPubKeys
+	discoSharedSecrets [2]key.DiscoShared
+	handeshakeState    [2]disco.BindHandshakeState
+	addrPorts          [2]netip.AddrPort
+	lastSeen           [2]time.Time
+	challenge          [2][disco.BindUDPEndpointChallengeLen]byte
+	vni                uint32
+	allocatedAt        time.Time
+}
+
+// expired reports whether e should be reclaimed at time now. An endpoint that
+// is not yet bound expires once more than bindLifetime has passed since
+// allocation; a bound endpoint expires once either client's lastSeen is more
+// than steadyStateLifetime in the past.
+func (e *serverEndpoint) expired(now time.Time, bindLifetime, steadyStateLifetime time.Duration) bool {
+	if !e.bound() {
+		return now.Sub(e.allocatedAt) > bindLifetime
+	}
+	for _, seen := range e.lastSeen {
+		if now.Sub(seen) > steadyStateLifetime {
+			return true
+		}
+	}
+	return false
+}
+
+// bound returns true if both clients have completed their 3-way handshake,
+// otherwise false.
+func (e *serverEndpoint) bound() bool {
+	return e.handeshakeState[0] == disco.BindHandshakeStateAnswerReceived &&
+		e.handeshakeState[1] == disco.BindHandshakeStateAnswerReceived
+}
+
+// NewServer constructs a Server listening on port and returning addrs:port
+// in response to allocation requests. On success it has started the packet
+// read loop and the endpoint garbage collector; call Close to stop them. It
+// returns an error if an addr:port pair cannot be formed or the UDP socket
+// cannot be opened.
+func NewServer(port uint16, addrs []netip.Addr) (*Server, error) {
+	s := &Server{
+		disco:               key.NewDisco(),
+		bindLifetime:        defaultBindLifetime,
+		steadyStateLifetime: defaultSteadyStateLifetime,
+		closeCh:             make(chan struct{}),
+		// The lookup maps must be non-nil: AllocateEndpoint assigns into them,
+		// and assigning to an entry of a nil map panics.
+		byVNI:   make(map[uint32]*serverEndpoint),
+		byDisco: make(map[pairOfDiscoPubKeys]*serverEndpoint),
+	}
+	s.discoPublic = s.disco.Public()
+	addrPorts := make([]netip.AddrPort, 0, len(addrs))
+	for _, addr := range addrs {
+		addrPort, err := netip.ParseAddrPort(net.JoinHostPort(addr.String(), strconv.Itoa(int(port))))
+		if err != nil {
+			return nil, err
+		}
+		addrPorts = append(addrPorts, addrPort)
+	}
+	s.addrPorts = addrPorts
+	// TODO: instead of allocating 10s of MBs for the full pool, allocate
+	// smaller chunks and increase only if needed
+	// VNI 0 is deliberately left out of the pool, which holds 1..1<<24-1.
+	s.vniPool = make([]uint32, 0, 1<<24-1)
+	for i := 1; i < 1<<24; i++ {
+		s.vniPool = append(s.vniPool, uint32(i))
+	}
+	// TODO: this assumes multi-af socket capability, but we should probably
+	// bind explicit ipv4 and ipv6 sockets.
+	uc, err := net.ListenUDP("udp", &net.UDPAddr{Port: int(port)})
+	if err != nil {
+		return nil, err
+	}
+	s.uc = uc
+	// One count each for packetReadLoop and endpointGC.
+	s.wg.Add(2)
+	go s.packetReadLoop()
+	go s.endpointGC()
+	return s, nil
+}
+
+// Close closes the server.
+// It may be called more than once; only the first call does any work. It
+// returns the error, if any, from closing the UDP socket on that first call.
+func (s *Server) Close() error {
+	var err error
+	s.closeOnce.Do(func() {
+		// Refuse new allocations before tearing anything down.
+		s.mu.Lock()
+		s.closed = true
+		s.mu.Unlock()
+
+		// Closing the socket makes packetReadLoop's read fail; closing
+		// closeCh stops endpointGC.
+		err = s.uc.Close()
+		close(s.closeCh)
+		// s.mu must NOT be held while waiting: handlePacket and the GC pass
+		// both acquire it, so waiting with the lock held can deadlock.
+		s.wg.Wait()
+
+		s.mu.Lock()
+		defer s.mu.Unlock()
+		clear(s.byVNI)
+		clear(s.byDisco)
+		s.vniPool = nil
+	})
+	return err
+}
+
+// endpointGC periodically removes expired endpoints and returns their VNIs to
+// the pool. It runs one pass every bindLifetime until closeCh is closed.
+func (s *Server) endpointGC() {
+	defer s.wg.Done()
+	ticker := time.NewTicker(s.bindLifetime)
+	defer ticker.Stop()
+
+	gc := func() {
+		now := time.Now()
+		s.mu.Lock()
+		defer s.mu.Unlock()
+		for k, v := range s.byDisco {
+			if v.expired(now, s.bindLifetime, s.steadyStateLifetime) {
+				delete(s.byDisco, k)
+				delete(s.byVNI, v.vni)
+				s.vniPool = append(s.vniPool, v.vni)
+			}
+		}
+	}
+
+	for {
+		select {
+		case <-ticker.C:
+			gc()
+		case <-s.closeCh:
+			return
+		}
+	}
+}
+
+// handlePacket processes one datagram b received from from. Packets that fail
+// any check are dropped silently.
+//
+// Non-control packets are relayed to the opposite client of a bound endpoint,
+// and refresh the sender's lastSeen time. Control packets are only accepted
+// while the endpoint is not yet bound, and drive the 3-way bind handshake.
+func (s *Server) handlePacket(from netip.AddrPort, b []byte) {
+	gh := packet.GeneveHeader{}
+	if err := gh.Decode(b); err != nil {
+		return
+	}
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	e, ok := s.byVNI[gh.VNI]
+	if !ok {
+		// unknown VNI
+		return
+	}
+
+	if !gh.Control {
+		if !e.bound() {
+			// not a control packet, but serverEndpoint isn't bound
+			return
+		}
+		var to netip.AddrPort
+		switch {
+		case from == e.addrPorts[0]:
+			// Refresh liveness; otherwise the endpoint would expire
+			// steadyStateLifetime after binding regardless of traffic.
+			e.lastSeen[0] = time.Now()
+			to = e.addrPorts[1]
+		case from == e.addrPorts[1]:
+			e.lastSeen[1] = time.Now()
+			to = e.addrPorts[0]
+		default:
+			// unrecognized source
+			return
+		}
+		// relay packet; best effort, a failed write is a dropped datagram
+		s.uc.WriteMsgUDPAddrPort(b, nil, to)
+		return
+	}
+
+	if e.bound() {
+		// control packet, but serverEndpoint is already bound
+		return
+	}
+
+	if gh.Protocol != packet.GeneveProtocolDisco {
+		// control packet, but not Disco
+		return
+	}
+
+	msg := b[packet.GeneveFixedHeaderLength:]
+	senderRaw, isDiscoMsg := disco.Source(msg)
+	if !isDiscoMsg {
+		// Geneve header protocol field indicated it was Disco, but it's not
+		return
+	}
+	sender := key.DiscoPublicFromRaw32(mem.B(senderRaw))
+	senderIndex := -1
+	switch {
+	case sender.Compare(e.discoPubKeys[0]) == 0:
+		senderIndex = 0
+	case sender.Compare(e.discoPubKeys[1]) == 0:
+		senderIndex = 1
+	default:
+		// unknown Disco public key
+		return
+	}
+
+	const headerLen = len(disco.Magic) + key.DiscoPublicRawLen
+	discoPayload, ok := e.discoSharedSecrets[senderIndex].Open(msg[headerLen:])
+	if !ok {
+		// unable to decrypt the disco payload
+		return
+	}
+
+	discoMsg, err := disco.Parse(discoPayload)
+	if err != nil {
+		// unable to parse the disco payload
+		return
+	}
+
+	handshakeState := e.handeshakeState[senderIndex]
+	if handshakeState == disco.BindHandshakeStateAnswerReceived {
+		// this sender is already bound
+		return
+	}
+	switch discoMsg := discoMsg.(type) {
+	case *disco.BindUDPEndpoint:
+		switch handshakeState {
+		case disco.BindHandshakeStateInit:
+			// generate a challenge, maybe we should do this at allocation time?
+			if _, err := rand.Read(e.challenge[senderIndex][:]); err != nil {
+				// never hand out a challenge that is not random
+				return
+			}
+			// set sender addr
+			e.addrPorts[senderIndex] = from
+			fallthrough
+		case disco.BindHandshakeStateChallengeSent:
+			if from != e.addrPorts[senderIndex] {
+				// this is a later arriving bind from a different source, discard
+				return
+			}
+			m := new(disco.BindUDPEndpointChallenge)
+			copy(m.Challenge[:], e.challenge[senderIndex][:])
+			reply := make([]byte, packet.GeneveFixedHeaderLength, 512)
+			// The received Geneve header is reused as-is for the reply.
+			if err := gh.Encode(reply); err != nil {
+				return
+			}
+			reply = append(reply, disco.Magic...)
+			reply = s.discoPublic.AppendTo(reply)
+			box := e.discoSharedSecrets[senderIndex].Seal(m.AppendMarshal(nil))
+			reply = append(reply, box...)
+			// best effort; the client can retransmit its bind
+			s.uc.WriteMsgUDPAddrPort(reply, nil, from)
+			// set new state
+			e.handeshakeState[senderIndex] = disco.BindHandshakeStateChallengeSent
+			return
+		default:
+			// disco.BindUDPEndpoint is unexpected in all other handshake states
+			return
+		}
+	case *disco.BindUDPEndpointAnswer:
+		switch handshakeState {
+		case disco.BindHandshakeStateChallengeSent:
+			if from != e.addrPorts[senderIndex] {
+				// sender source has changed
+				return
+			}
+			if !bytes.Equal(discoMsg.Answer[:], e.challenge[senderIndex][:]) {
+				// bad answer
+				return
+			}
+			// sender is now bound
+			e.handeshakeState[senderIndex] = disco.BindHandshakeStateAnswerReceived
+			// record last seen as bound time
+			e.lastSeen[senderIndex] = time.Now()
+			return
+		default:
+			// disco.BindUDPEndpointAnswer is unexpected in all other handshake
+			// states, or we've already handled it
+			return
+		}
+	default:
+		// unexpected Disco message type
+		return
+	}
+}
+
+// packetReadLoop reads datagrams from the UDP socket and hands each one to
+// handlePacket until a read fails, then closes the server.
+func (s *Server) packetReadLoop() {
+	defer func() {
+		// Done must come before Close, because Close waits on s.wg.
+		s.wg.Done()
+		s.Close()
+	}()
+	b := make([]byte, 1<<16-1)
+	for {
+		n, from, err := s.uc.ReadFromUDPAddrPort(b)
+		if err != nil {
+			return
+		}
+		s.handlePacket(from, b[:n])
+	}
+}
+
+// ErrServerClosed is returned by AllocateEndpoint once the Server is closed.
+var ErrServerClosed = errors.New("server closed")
+
+// serverEndpointFor builds the ServerEndpoint returned to callers for vni, so
+// that every AllocateEndpoint return path reports the same server details.
+func (s *Server) serverEndpointFor(vni uint32) ServerEndpoint {
+	return ServerEndpoint{
+		ServerDisco:         s.discoPublic,
+		AddrPorts:           s.addrPorts,
+		VNI:                 vni,
+		BindLifetime:        s.bindLifetime,
+		SteadyStateLifetime: s.steadyStateLifetime,
+	}
+}
+
+// AllocateEndpoint allocates a ServerEndpoint for the provided pair of
+// key.DiscoPublic's. It returns ErrServerClosed if the server has been closed,
+// and an error if no VNI is available.
+func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (ServerEndpoint, error) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	if s.closed {
+		return ServerEndpoint{}, ErrServerClosed
+	}
+
+	pair := newPairOfDiscoPubKeys(discoA, discoB)
+	e, ok := s.byDisco[pair]
+	if ok {
+		if !e.bound() {
+			// If the endpoint is not yet bound this is likely an allocation
+			// race between two clients utilizing the same relay. Instead of
+			// re-allocating we return the existing allocation state, and reset
+			// the allocation time.
+			e.allocatedAt = time.Now()
+			return s.serverEndpointFor(e.vni), nil
+		}
+		// If an endpoint exists for the pair of key.DiscoPublic's, and is
+		// already bound, delete it. We will re-allocate a new endpoint. Chances
+		// are clients cannot make use of the existing, bound allocation if
+		// they are requesting a new one.
+		delete(s.byDisco, pair)
+		delete(s.byVNI, e.vni)
+		s.vniPool = append(s.vniPool, e.vni)
+	}
+
+	if len(s.vniPool) == 0 {
+		return ServerEndpoint{}, errors.New("VNI pool exhausted")
+	}
+
+	e = &serverEndpoint{
+		discoPubKeys: pair,
+		allocatedAt:  time.Now(),
+	}
+	e.discoSharedSecrets[0] = s.disco.Shared(e.discoPubKeys[0])
+	e.discoSharedSecrets[1] = s.disco.Shared(e.discoPubKeys[1])
+	// take the VNI at the front of the pool
+	e.vni, s.vniPool = s.vniPool[0], s.vniPool[1:]
+	s.byDisco[pair] = e
+	s.byVNI[e.vni] = e
+
+	// A fresh allocation reports the same details as the existing-allocation
+	// path above, including ServerDisco and the Server's configured lifetimes.
+	return s.serverEndpointFor(e.vni), nil
+}