diff --git a/disco/disco.go b/disco/disco.go index b9a90029d..c5aa4ace2 100644 --- a/disco/disco.go +++ b/disco/disco.go @@ -41,9 +41,12 @@ const NonceLen = 24 type MessageType byte const ( - TypePing = MessageType(0x01) - TypePong = MessageType(0x02) - TypeCallMeMaybe = MessageType(0x03) + TypePing = MessageType(0x01) + TypePong = MessageType(0x02) + TypeCallMeMaybe = MessageType(0x03) + TypeBindUDPRelayEndpoint = MessageType(0x04) + TypeBindUDPRelayEndpointChallenge = MessageType(0x05) + TypeBindUDPRelayEndpointAnswer = MessageType(0x06) ) const v0 = byte(0) @@ -77,12 +80,19 @@ func Parse(p []byte) (Message, error) { } t, ver, p := MessageType(p[0]), p[1], p[2:] switch t { + // TODO(jwhited): consider using a signature matching encoding.BinaryUnmarshaler case TypePing: return parsePing(ver, p) case TypePong: return parsePong(ver, p) case TypeCallMeMaybe: return parseCallMeMaybe(ver, p) + case TypeBindUDPRelayEndpoint: + return parseBindUDPRelayEndpoint(ver, p) + case TypeBindUDPRelayEndpointChallenge: + return parseBindUDPRelayEndpointChallenge(ver, p) + case TypeBindUDPRelayEndpointAnswer: + return parseBindUDPRelayEndpointAnswer(ver, p) default: return nil, fmt.Errorf("unknown message type 0x%02x", byte(t)) } @@ -91,6 +101,7 @@ func Parse(p []byte) (Message, error) { // Message a discovery message. type Message interface { // AppendMarshal appends the message's marshaled representation. + // TODO(jwhited): consider using a signature matching encoding.BinaryAppender AppendMarshal([]byte) []byte } @@ -266,7 +277,118 @@ func MessageSummary(m Message) string { return fmt.Sprintf("pong tx=%x", m.TxID[:6]) case *CallMeMaybe: return "call-me-maybe" + case *BindUDPRelayEndpoint: + return "bind-udp-relay-endpoint" + case *BindUDPRelayEndpointChallenge: + return "bind-udp-relay-endpoint-challenge" + case *BindUDPRelayEndpointAnswer: + return "bind-udp-relay-endpoint-answer" default: return fmt.Sprintf("%#v", m) } } + +// BindUDPRelayHandshakeState represents the state of the 3-way bind handshake +// between UDP relay client and UDP relay server. Its potential values include +// those for both participants, UDP relay client and UDP relay server. A UDP +// relay server implementation can be found in net/udprelay. This is currently +// considered experimental. +type BindUDPRelayHandshakeState int + +const ( + // BindUDPRelayHandshakeStateInit represents the initial state prior to any + // message being transmitted. + BindUDPRelayHandshakeStateInit BindUDPRelayHandshakeState = iota + // BindUDPRelayHandshakeStateBindSent is the first client state after + // transmitting a BindUDPRelayEndpoint message to a UDP relay server. + BindUDPRelayHandshakeStateBindSent + // BindUDPRelayHandshakeStateChallengeSent is the first server state after + // receiving a BindUDPRelayEndpoint message from a UDP relay client and + // replying with a BindUDPRelayEndpointChallenge. + BindUDPRelayHandshakeStateChallengeSent + // BindUDPRelayHandshakeStateAnswerSent is a client state that is entered + // after transmitting a BindUDPRelayEndpointAnswer message towards a UDP + // relay server in response to a BindUDPRelayEndpointChallenge message. + BindUDPRelayHandshakeStateAnswerSent + // BindUDPRelayHandshakeStateAnswerReceived is a server state that is + // entered after it has received a correct BindUDPRelayEndpointAnswer + // message from a UDP relay client in response to a + // BindUDPRelayEndpointChallenge message. + BindUDPRelayHandshakeStateAnswerReceived +) + +// bindUDPRelayEndpointLen is the length of a marshalled BindUDPRelayEndpoint +// message, without the message header. +const bindUDPRelayEndpointLen = BindUDPRelayEndpointChallengeLen + +// BindUDPRelayEndpoint is the first messaged transmitted from UDP relay client +// towards UDP relay server as part of the 3-way bind handshake. It is padded to +// match the length of BindUDPRelayEndpointChallenge. This message type is +// currently considered experimental and is not yet tied to a +// tailcfg.CapabilityVersion. +type BindUDPRelayEndpoint struct { +} + +func (m *BindUDPRelayEndpoint) AppendMarshal(b []byte) []byte { + ret, _ := appendMsgHeader(b, TypeBindUDPRelayEndpoint, v0, bindUDPRelayEndpointLen) + return ret +} + +func parseBindUDPRelayEndpoint(ver uint8, p []byte) (m *BindUDPRelayEndpoint, err error) { + m = new(BindUDPRelayEndpoint) + return m, nil +} + +// BindUDPRelayEndpointChallengeLen is the length of a marshalled +// BindUDPRelayEndpointChallenge message, without the message header. +const BindUDPRelayEndpointChallengeLen = 32 + +// BindUDPRelayEndpointChallenge is transmitted from UDP relay server towards +// UDP relay client in response to a BindUDPRelayEndpoint message as part of the +// 3-way bind handshake. This message type is currently considered experimental +// and is not yet tied to a tailcfg.CapabilityVersion. +type BindUDPRelayEndpointChallenge struct { + Challenge [BindUDPRelayEndpointChallengeLen]byte +} + +func (m *BindUDPRelayEndpointChallenge) AppendMarshal(b []byte) []byte { + ret, d := appendMsgHeader(b, TypeBindUDPRelayEndpointChallenge, v0, BindUDPRelayEndpointChallengeLen) + copy(d, m.Challenge[:]) + return ret +} + +func parseBindUDPRelayEndpointChallenge(ver uint8, p []byte) (m *BindUDPRelayEndpointChallenge, err error) { + if len(p) < BindUDPRelayEndpointChallengeLen { + return nil, errShort + } + m = new(BindUDPRelayEndpointChallenge) + copy(m.Challenge[:], p[:]) + return m, nil +} + +// bindUDPRelayEndpointAnswerLen is the length of a marshalled +// BindUDPRelayEndpointAnswer message, without the message header. +const bindUDPRelayEndpointAnswerLen = BindUDPRelayEndpointChallengeLen + +// BindUDPRelayEndpointAnswer is transmitted from UDP relay client to UDP relay +// server in response to a BindUDPRelayEndpointChallenge message. This message +// type is currently considered experimental and is not yet tied to a +// tailcfg.CapabilityVersion. +type BindUDPRelayEndpointAnswer struct { + Answer [bindUDPRelayEndpointAnswerLen]byte +} + +func (m *BindUDPRelayEndpointAnswer) AppendMarshal(b []byte) []byte { + ret, d := appendMsgHeader(b, TypeBindUDPRelayEndpointAnswer, v0, bindUDPRelayEndpointAnswerLen) + copy(d, m.Answer[:]) + return ret +} + +func parseBindUDPRelayEndpointAnswer(ver uint8, p []byte) (m *BindUDPRelayEndpointAnswer, err error) { + if len(p) < bindUDPRelayEndpointAnswerLen { + return nil, errShort + } + m = new(BindUDPRelayEndpointAnswer) + copy(m.Answer[:], p[:]) + return m, nil +} diff --git a/disco/disco_test.go b/disco/disco_test.go index 1a56324a5..751190445 100644 --- a/disco/disco_test.go +++ b/disco/disco_test.go @@ -83,6 +83,29 @@ func TestMarshalAndParse(t *testing.T) { }, want: "03 00 00 00 00 00 00 00 00 00 00 00 ff ff 01 02 03 04 02 37 20 01 00 00 00 00 00 00 00 00 00 00 00 00 34 56 03 15", }, + { + name: "bind_udp_relay_endpoint", + m: &BindUDPRelayEndpoint{}, + want: "04 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00", + }, + { + name: "bind_udp_relay_endpoint_challenge", + m: &BindUDPRelayEndpointChallenge{ + Challenge: [BindUDPRelayEndpointChallengeLen]byte{ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, + }, + }, + want: "05 00 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f 10 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f", + }, + { + name: "bind_udp_relay_endpoint_answer", + m: &BindUDPRelayEndpointAnswer{ + Answer: [bindUDPRelayEndpointAnswerLen]byte{ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, + }, + }, + want: "06 00 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f 10 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/net/udprelay/server.go b/net/udprelay/server.go new file mode 100644 index 000000000..30fc08326 --- /dev/null +++ b/net/udprelay/server.go @@ -0,0 +1,532 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package udprelay contains constructs for relaying Disco and WireGuard packets +// between Tailscale clients over UDP. This package is currently considered +// experimental. +package udprelay + +import ( + "bytes" + "crypto/rand" + "errors" + "fmt" + "net" + "net/netip" + "slices" + "strconv" + "sync" + "time" + + "go4.org/mem" + "tailscale.com/disco" + "tailscale.com/net/packet" + "tailscale.com/types/key" +) + +const ( + // defaultBindLifetime is somewhat arbitrary. We attempt to account for + // high latency between client and Server, and high latency between + // clients over side channels, e.g. DERP, used to exchange ServerEndpoint + // details. So, a total of 3 paths with potentially high latency. Using a + // conservative 10s "high latency" bounds for each path we end up at a 30s + // total. It is worse to set an aggressive bind lifetime as this may lead + // to path discovery failure, vs dealing with a slight increase of Server + // resource utilization (VNIs, RAM, etc) while tracking endpoints that won't + // bind. + defaultBindLifetime = time.Second * 30 + defaultSteadyStateLifetime = time.Minute * 5 +) + +// Server implements an experimental UDP relay server. +type Server struct { + // disco keypair used as part of 3-way bind handshake + disco key.DiscoPrivate + discoPublic key.DiscoPublic + + bindLifetime time.Duration + steadyStateLifetime time.Duration + + // addrPorts contains the ip:port pairs returned as candidate server + // endpoints in response to an allocation request. + addrPorts []netip.AddrPort + + uc *net.UDPConn + + closeOnce sync.Once + wg sync.WaitGroup + closeCh chan struct{} + closed bool + + mu sync.Mutex // guards the following fields + lamportID uint64 + vniPool []uint32 // the pool of available VNIs + byVNI map[uint32]*serverEndpoint + byDisco map[pairOfDiscoPubKeys]*serverEndpoint +} + +// pairOfDiscoPubKeys is a pair of key.DiscoPublic. It must be constructed via +// newPairOfDiscoPubKeys to ensure lexicographical ordering. +type pairOfDiscoPubKeys [2]key.DiscoPublic + +func (p pairOfDiscoPubKeys) String() string { + return fmt.Sprintf("%s <=> %s", p[0].ShortString(), p[1].ShortString()) +} + +func newPairOfDiscoPubKeys(discoA, discoB key.DiscoPublic) pairOfDiscoPubKeys { + pair := pairOfDiscoPubKeys([2]key.DiscoPublic{discoA, discoB}) + slices.SortFunc(pair[:], func(a, b key.DiscoPublic) int { + return a.Compare(b) + }) + return pair +} + +// ServerEndpoint contains the Server's endpoint details. +type ServerEndpoint struct { + // ServerDisco is the Server's Disco public key used as part of the 3-way + // bind handshake. Server will use the same ServerDisco for its lifetime. + // ServerDisco value in combination with LamportID value represents a + // unique ServerEndpoint allocation. + ServerDisco key.DiscoPublic + + // LamportID is unique and monotonically non-decreasing across + // ServerEndpoint allocations for the lifetime of Server. It enables clients + // to dedup and resolve allocation event order. Clients may race to allocate + // on the same Server, and signal ServerEndpoint details via alternative + // channels, e.g. DERP. Additionally, Server.AllocateEndpoint() requests may + // not result in a new allocation depending on existing server-side endpoint + // state. Therefore, where clients have local, existing state that contains + // ServerDisco and LamportID values matching a newly learned endpoint, these + // can be considered one and the same. If ServerDisco is equal, but + // LamportID is unequal, LamportID comparison determines which + // ServerEndpoint was allocated most recently. + LamportID uint64 + + // AddrPorts are the IP:Port candidate pairs the Server may be reachable + // over. + AddrPorts []netip.AddrPort + + // VNI (Virtual Network Identifier) is the Geneve header VNI the Server + // will use for transmitted packets, and expects for received packets + // associated with this endpoint. + VNI uint32 + + // BindLifetime is amount of time post-allocation the Server will consider + // the endpoint active while it has yet to be bound via 3-way bind handshake + // from both client parties. + BindLifetime time.Duration + + // SteadyStateLifetime is the amount of time post 3-way bind handshake from + // both client parties the Server will consider the endpoint active lacking + // bidirectional data flow. + SteadyStateLifetime time.Duration +} + +// serverEndpoint contains Server-internal ServerEndpoint state. serverEndpoint +// methods are not thread-safe. +type serverEndpoint struct { + // discoPubKeys contains the key.DiscoPublic of the served clients. The + // indexing of this array aligns with the following fields, e.g. + // discoSharedSecrets[0] is the shared secret to use when sealing + // Disco protocol messages for transmission towards discoPubKeys[0]. + discoPubKeys pairOfDiscoPubKeys + discoSharedSecrets [2]key.DiscoShared + handshakeState [2]disco.BindUDPRelayHandshakeState + addrPorts [2]netip.AddrPort + lastSeen [2]time.Time // TODO(jwhited): consider using mono.Time + challenge [2][disco.BindUDPRelayEndpointChallengeLen]byte + + lamportID uint64 + vni uint32 + allocatedAt time.Time +} + +func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex int, discoMsg disco.Message, uw udpWriter, serverDisco key.DiscoPublic) { + if senderIndex != 0 && senderIndex != 1 { + return + } + handshakeState := e.handshakeState[senderIndex] + if handshakeState == disco.BindUDPRelayHandshakeStateAnswerReceived { + // this sender is already bound + return + } + switch discoMsg := discoMsg.(type) { + case *disco.BindUDPRelayEndpoint: + switch handshakeState { + case disco.BindUDPRelayHandshakeStateInit: + // set sender addr + e.addrPorts[senderIndex] = from + fallthrough + case disco.BindUDPRelayHandshakeStateChallengeSent: + if from != e.addrPorts[senderIndex] { + // this is a later arriving bind from a different source, or + // a retransmit and the sender's source has changed, discard + return + } + m := new(disco.BindUDPRelayEndpointChallenge) + copy(m.Challenge[:], e.challenge[senderIndex][:]) + reply := make([]byte, packet.GeneveFixedHeaderLength, 512) + gh := packet.GeneveHeader{Control: true, VNI: e.vni, Protocol: packet.GeneveProtocolDisco} + err := gh.Encode(reply) + if err != nil { + return + } + reply = append(reply, disco.Magic...) + reply = serverDisco.AppendTo(reply) + box := e.discoSharedSecrets[senderIndex].Seal(m.AppendMarshal(nil)) + reply = append(reply, box...) + uw.WriteMsgUDPAddrPort(reply, nil, from) + // set new state + e.handshakeState[senderIndex] = disco.BindUDPRelayHandshakeStateChallengeSent + return + default: + // disco.BindUDPRelayEndpoint is unexpected in all other handshake states + return + } + case *disco.BindUDPRelayEndpointAnswer: + switch handshakeState { + case disco.BindUDPRelayHandshakeStateChallengeSent: + if from != e.addrPorts[senderIndex] { + // sender source has changed + return + } + if !bytes.Equal(discoMsg.Answer[:], e.challenge[senderIndex][:]) { + // bad answer + return + } + // sender is now bound + // TODO: Consider installing a fast path via netfilter or similar to + // relay (NAT) data packets for this serverEndpoint. + e.handshakeState[senderIndex] = disco.BindUDPRelayHandshakeStateAnswerReceived + // record last seen as bound time + e.lastSeen[senderIndex] = time.Now() + return + default: + // disco.BindUDPRelayEndpointAnswer is unexpected in all other handshake + // states, or we've already handled it + return + } + default: + // unexpected Disco message type + return + } +} + +func (e *serverEndpoint) handleSealedDiscoControlMsg(from netip.AddrPort, b []byte, uw udpWriter, serverDisco key.DiscoPublic) { + senderRaw, isDiscoMsg := disco.Source(b) + if !isDiscoMsg { + // Not a Disco message + return + } + sender := key.DiscoPublicFromRaw32(mem.B(senderRaw)) + senderIndex := -1 + switch { + case sender.Compare(e.discoPubKeys[0]) == 0: + senderIndex = 0 + case sender.Compare(e.discoPubKeys[1]) == 0: + senderIndex = 1 + default: + // unknown Disco public key + return + } + + const headerLen = len(disco.Magic) + key.DiscoPublicRawLen + discoPayload, ok := e.discoSharedSecrets[senderIndex].Open(b[headerLen:]) + if !ok { + // unable to decrypt the Disco payload + return + } + + discoMsg, err := disco.Parse(discoPayload) + if err != nil { + // unable to parse the Disco payload + return + } + + e.handleDiscoControlMsg(from, senderIndex, discoMsg, uw, serverDisco) +} + +type udpWriter interface { + WriteMsgUDPAddrPort(b []byte, oob []byte, addr netip.AddrPort) (n, oobn int, err error) +} + +func (e *serverEndpoint) handlePacket(from netip.AddrPort, gh packet.GeneveHeader, b []byte, uw udpWriter, serverDisco key.DiscoPublic) { + if !gh.Control { + if !e.isBound() { + // not a control packet, but serverEndpoint isn't bound + return + } + var to netip.AddrPort + switch { + case from == e.addrPorts[0]: + e.lastSeen[0] = time.Now() + to = e.addrPorts[1] + case from == e.addrPorts[1]: + e.lastSeen[1] = time.Now() + to = e.addrPorts[0] + default: + // unrecognized source + return + } + // relay packet + uw.WriteMsgUDPAddrPort(b, nil, to) + return + } + + if e.isBound() { + // control packet, but serverEndpoint is already bound + return + } + + if gh.Protocol != packet.GeneveProtocolDisco { + // control packet, but not Disco + return + } + + msg := b[packet.GeneveFixedHeaderLength:] + e.handleSealedDiscoControlMsg(from, msg, uw, serverDisco) +} + +func (e *serverEndpoint) isExpired(now time.Time, bindLifetime, steadyStateLifetime time.Duration) bool { + if !e.isBound() { + if now.Sub(e.allocatedAt) > bindLifetime { + return true + } + return false + } + if now.Sub(e.lastSeen[0]) > steadyStateLifetime || now.Sub(e.lastSeen[1]) > steadyStateLifetime { + return true + } + return false +} + +// isBound returns true if both clients have completed their 3-way handshake, +// otherwise false. +func (e *serverEndpoint) isBound() bool { + return e.handshakeState[0] == disco.BindUDPRelayHandshakeStateAnswerReceived && + e.handshakeState[1] == disco.BindUDPRelayHandshakeStateAnswerReceived +} + +// NewServer constructs a Server listening on 0.0.0.0:'port'. IPv6 is not yet +// supported. Port may be 0, and what ultimately gets bound is returned as +// 'boundPort'. Supplied 'addrs' are joined with 'boundPort' and returned as +// ServerEndpoint.AddrPorts in response to Server.AllocateEndpoint() requests. +// +// TODO: IPv6 support +// TODO: dynamic addrs:port discovery +func NewServer(port int, addrs []netip.Addr) (s *Server, boundPort int, err error) { + s = &Server{ + disco: key.NewDisco(), + bindLifetime: defaultBindLifetime, + steadyStateLifetime: defaultSteadyStateLifetime, + closeCh: make(chan struct{}), + byDisco: make(map[pairOfDiscoPubKeys]*serverEndpoint), + byVNI: make(map[uint32]*serverEndpoint), + } + s.discoPublic = s.disco.Public() + // TODO: instead of allocating 10s of MBs for the full pool, allocate + // smaller chunks and increase as needed + s.vniPool = make([]uint32, 0, 1<<24-1) + for i := 1; i < 1<<24; i++ { + s.vniPool = append(s.vniPool, uint32(i)) + } + boundPort, err = s.listenOn(port) + if err != nil { + return nil, 0, err + } + addrPorts := make([]netip.AddrPort, 0, len(addrs)) + for _, addr := range addrs { + addrPort, err := netip.ParseAddrPort(net.JoinHostPort(addr.String(), strconv.Itoa(boundPort))) + if err != nil { + return nil, 0, err + } + addrPorts = append(addrPorts, addrPort) + } + s.addrPorts = addrPorts + s.wg.Add(2) + go s.packetReadLoop() + go s.endpointGCLoop() + return s, boundPort, nil +} + +func (s *Server) listenOn(port int) (int, error) { + uc, err := net.ListenUDP("udp4", &net.UDPAddr{Port: port}) + if err != nil { + return 0, err + } + // TODO: set IP_PKTINFO sockopt + _, boundPortStr, err := net.SplitHostPort(uc.LocalAddr().String()) + if err != nil { + s.uc.Close() + return 0, err + } + boundPort, err := strconv.Atoi(boundPortStr) + if err != nil { + s.uc.Close() + return 0, err + } + s.uc = uc + return boundPort, nil +} + +// Close closes the server. +func (s *Server) Close() error { + s.closeOnce.Do(func() { + s.mu.Lock() + defer s.mu.Unlock() + s.uc.Close() + close(s.closeCh) + s.wg.Wait() + clear(s.byVNI) + clear(s.byDisco) + s.vniPool = nil + s.closed = true + }) + return nil +} + +func (s *Server) endpointGCLoop() { + defer s.wg.Done() + ticker := time.NewTicker(s.bindLifetime) + defer ticker.Stop() + + gc := func() { + now := time.Now() + // TODO: consider performance implications of scanning all endpoints and + // holding s.mu for the duration. Keep it simple (and slow) for now. + s.mu.Lock() + defer s.mu.Unlock() + for k, v := range s.byDisco { + if v.isExpired(now, s.bindLifetime, s.steadyStateLifetime) { + delete(s.byDisco, k) + delete(s.byVNI, v.vni) + s.vniPool = append(s.vniPool, v.vni) + } + } + } + + for { + select { + case <-ticker.C: + gc() + case <-s.closeCh: + return + } + } +} + +func (s *Server) handlePacket(from netip.AddrPort, b []byte, uw udpWriter) { + gh := packet.GeneveHeader{} + err := gh.Decode(b) + if err != nil { + return + } + // TODO: consider performance implications of holding s.mu for the remainder + // of this method, which does a bunch of disco/crypto work depending. Keep + // it simple (and slow) for now. + s.mu.Lock() + defer s.mu.Unlock() + e, ok := s.byVNI[gh.VNI] + if !ok { + // unknown VNI + return + } + + e.handlePacket(from, gh, b, uw, s.discoPublic) +} + +func (s *Server) packetReadLoop() { + defer func() { + s.wg.Done() + s.Close() + }() + b := make([]byte, 1<<16-1) + for { + // TODO: extract laddr from IP_PKTINFO for use in reply + n, from, err := s.uc.ReadFromUDPAddrPort(b) + if err != nil { + return + } + s.handlePacket(from, b[:n], s.uc) + } +} + +var ErrServerClosed = errors.New("server closed") + +// AllocateEndpoint allocates a ServerEndpoint for the provided pair of +// key.DiscoPublic's. It returns an error (ErrServerClosed) if the server has +// been closed. +func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (ServerEndpoint, error) { + s.mu.Lock() + defer s.mu.Unlock() + if s.closed { + return ServerEndpoint{}, ErrServerClosed + } + + if discoA.Compare(s.discoPublic) == 0 || discoB.Compare(s.discoPublic) == 0 { + return ServerEndpoint{}, fmt.Errorf("client disco equals server disco: %s", s.discoPublic.ShortString()) + } + + pair := newPairOfDiscoPubKeys(discoA, discoB) + e, ok := s.byDisco[pair] + if ok { + if !e.isBound() { + // If the endpoint is not yet bound this is likely an allocation + // race between two clients on the same Server. Instead of + // re-allocating we return the existing allocation. We do not reset + // e.allocatedAt in case a client is "stuck" in an allocation + // loop and will not be able to complete a handshake, for whatever + // reason. Once the endpoint expires a new endpoint will be + // allocated. Clients can resolve duplicate ServerEndpoint details + // via ServerEndpoint.LamportID. + // + // TODO: consider ServerEndpoint.BindLifetime -= time.Now()-e.allocatedAt + // to give the client a more accurate picture of the bind window. + // Or, some threshold to trigger re-allocation if too much time has + // already passed since it was originally allocated. + return ServerEndpoint{ + ServerDisco: s.discoPublic, + AddrPorts: s.addrPorts, + VNI: e.vni, + LamportID: e.lamportID, + BindLifetime: s.bindLifetime, + SteadyStateLifetime: s.steadyStateLifetime, + }, nil + } + // If an endpoint exists for the pair of key.DiscoPublic's, and is + // already bound, delete it. We will re-allocate a new endpoint. Chances + // are clients cannot make use of the existing, bound allocation if + // they are requesting a new one. + delete(s.byDisco, pair) + delete(s.byVNI, e.vni) + s.vniPool = append(s.vniPool, e.vni) + } + + if len(s.vniPool) == 0 { + return ServerEndpoint{}, errors.New("VNI pool exhausted") + } + + s.lamportID++ + e = &serverEndpoint{ + discoPubKeys: pair, + lamportID: s.lamportID, + allocatedAt: time.Now(), + } + e.discoSharedSecrets[0] = s.disco.Shared(e.discoPubKeys[0]) + e.discoSharedSecrets[1] = s.disco.Shared(e.discoPubKeys[1]) + e.vni, s.vniPool = s.vniPool[0], s.vniPool[1:] + rand.Read(e.challenge[0][:]) + rand.Read(e.challenge[1][:]) + + s.byDisco[pair] = e + s.byVNI[e.vni] = e + + return ServerEndpoint{ + ServerDisco: s.discoPublic, + AddrPorts: s.addrPorts, + VNI: e.vni, + LamportID: e.lamportID, + BindLifetime: s.bindLifetime, + SteadyStateLifetime: s.steadyStateLifetime, + }, nil +} diff --git a/net/udprelay/server_test.go b/net/udprelay/server_test.go new file mode 100644 index 000000000..733e50b77 --- /dev/null +++ b/net/udprelay/server_test.go @@ -0,0 +1,204 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package udprelay + +import ( + "bytes" + "net" + "net/netip" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "go4.org/mem" + "tailscale.com/disco" + "tailscale.com/net/packet" + "tailscale.com/types/key" +) + +type testClient struct { + vni uint32 + local key.DiscoPrivate + server key.DiscoPublic + uc *net.UDPConn +} + +func newTestClient(t *testing.T, vni uint32, serverEndpoint netip.AddrPort, local key.DiscoPrivate, server key.DiscoPublic) *testClient { + rAddr := &net.UDPAddr{IP: serverEndpoint.Addr().AsSlice(), Port: int(serverEndpoint.Port())} + uc, err := net.DialUDP("udp4", nil, rAddr) + if err != nil { + t.Fatal(err) + } + return &testClient{ + vni: vni, + local: local, + server: server, + uc: uc, + } +} + +func (c *testClient) write(t *testing.T, b []byte) { + _, err := c.uc.Write(b) + if err != nil { + t.Fatal(err) + } +} + +func (c *testClient) read(t *testing.T) []byte { + c.uc.SetReadDeadline(time.Now().Add(time.Second)) + b := make([]byte, 1<<16-1) + n, err := c.uc.Read(b) + if err != nil { + t.Fatal(err) + } + return b[:n] +} + +func (c *testClient) writeDataPkt(t *testing.T, b []byte) { + pkt := make([]byte, packet.GeneveFixedHeaderLength, packet.GeneveFixedHeaderLength+len(b)) + gh := packet.GeneveHeader{Control: false, VNI: c.vni, Protocol: packet.GeneveProtocolWireGuard} + err := gh.Encode(pkt) + if err != nil { + t.Fatal(err) + } + pkt = append(pkt, b...) + c.write(t, pkt) +} + +func (c *testClient) readDataPkt(t *testing.T) []byte { + b := c.read(t) + gh := packet.GeneveHeader{} + err := gh.Decode(b) + if err != nil { + t.Fatal(err) + } + if gh.Protocol != packet.GeneveProtocolWireGuard { + t.Fatal("unexpected geneve protocol") + } + if gh.Control { + t.Fatal("unexpected control") + } + if gh.VNI != c.vni { + t.Fatal("unexpected vni") + } + return b[packet.GeneveFixedHeaderLength:] +} + +func (c *testClient) writeControlDiscoMsg(t *testing.T, msg disco.Message) { + pkt := make([]byte, packet.GeneveFixedHeaderLength, 512) + gh := packet.GeneveHeader{Control: true, VNI: c.vni, Protocol: packet.GeneveProtocolDisco} + err := gh.Encode(pkt) + if err != nil { + t.Fatal(err) + } + pkt = append(pkt, disco.Magic...) + pkt = c.local.Public().AppendTo(pkt) + box := c.local.Shared(c.server).Seal(msg.AppendMarshal(nil)) + pkt = append(pkt, box...) + c.write(t, pkt) +} + +func (c *testClient) readControlDiscoMsg(t *testing.T) disco.Message { + b := c.read(t) + gh := packet.GeneveHeader{} + err := gh.Decode(b) + if err != nil { + t.Fatal(err) + } + if gh.Protocol != packet.GeneveProtocolDisco { + t.Fatal("unexpected geneve protocol") + } + if !gh.Control { + t.Fatal("unexpected non-control") + } + if gh.VNI != c.vni { + t.Fatal("unexpected vni") + } + b = b[packet.GeneveFixedHeaderLength:] + headerLen := len(disco.Magic) + key.DiscoPublicRawLen + if len(b) < headerLen { + t.Fatal("disco message too short") + } + sender := key.DiscoPublicFromRaw32(mem.B(b[len(disco.Magic):headerLen])) + if sender.Compare(c.server) != 0 { + t.Fatal("unknown disco public key") + } + payload, ok := c.local.Shared(c.server).Open(b[headerLen:]) + if !ok { + t.Fatal("failed to open sealed disco msg") + } + msg, err := disco.Parse(payload) + if err != nil { + t.Fatal("failed to parse disco payload") + } + return msg +} + +func (c *testClient) handshake(t *testing.T) { + c.writeControlDiscoMsg(t, &disco.BindUDPRelayEndpoint{}) + msg := c.readControlDiscoMsg(t) + challenge, ok := msg.(*disco.BindUDPRelayEndpointChallenge) + if !ok { + t.Fatal("unexepcted disco message type") + } + c.writeControlDiscoMsg(t, &disco.BindUDPRelayEndpointAnswer{Answer: challenge.Challenge}) +} + +func (c *testClient) close() { + c.uc.Close() +} + +func TestServer(t *testing.T) { + discoA := key.NewDisco() + discoB := key.NewDisco() + + ipv4LoopbackAddr := netip.MustParseAddr("127.0.0.1") + + server, _, err := NewServer(0, []netip.Addr{ipv4LoopbackAddr}) + if err != nil { + t.Fatal(err) + } + defer server.Close() + + endpoint, err := server.AllocateEndpoint(discoA.Public(), discoB.Public()) + if err != nil { + t.Fatal(err) + } + dupEndpoint, err := server.AllocateEndpoint(discoA.Public(), discoB.Public()) + if err != nil { + t.Fatal(err) + } + + // We expect the same endpoint details as the 3-way bind handshake has not + // yet been completed for both relay client parties. + if diff := cmp.Diff(dupEndpoint, endpoint, cmpopts.EquateComparable(netip.AddrPort{}, key.DiscoPublic{})); diff != "" { + t.Fatalf("wrong dupEndpoint (-got +want)\n%s", diff) + } + + if len(endpoint.AddrPorts) != 1 { + t.Fatalf("unexpected endpoint.AddrPorts: %v", endpoint.AddrPorts) + } + tcA := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoA, endpoint.ServerDisco) + defer tcA.close() + tcB := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoB, endpoint.ServerDisco) + defer tcB.close() + + tcA.handshake(t) + tcB.handshake(t) + + txToB := []byte{1, 2, 3} + tcA.writeDataPkt(t, txToB) + rxFromA := tcB.readDataPkt(t) + if !bytes.Equal(txToB, rxFromA) { + t.Fatal("unexpected msg A->B") + } + + txToA := []byte{4, 5, 6} + tcB.writeDataPkt(t, txToA) + rxFromB := tcA.readDataPkt(t) + if !bytes.Equal(txToA, rxFromB) { + t.Fatal("unexpected msg B->A") + } +}