// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause

// Package udprelay contains constructs for relaying Disco and WireGuard packets
// between Tailscale clients over UDP. This package is currently considered
// experimental.
package udprelay

import (
	"bytes"
	"crypto/rand"
	"errors"
	"fmt"
	"net"
	"net/netip"
	"slices"
	"strconv"
	"sync"
	"time"

	"go4.org/mem"
	"tailscale.com/disco"
	"tailscale.com/net/packet"
	"tailscale.com/tstime"
	"tailscale.com/types/key"
)

const (
	// defaultBindLifetime bounds how long an allocated endpoint may wait for
	// both clients to complete the 3-way bind handshake before it expires.
	//
	// The value is somewhat arbitrary. It budgets for three potentially
	// high-latency paths: client <-> Server, plus the client <-> client side
	// channels (e.g. DERP) used to exchange ServerEndpoint details. Each path
	// gets a conservative 10s, for 30s in total. Erring long only costs extra
	// Server resources (VNIs, RAM, etc) spent on endpoints that never bind.
	// Erring short risks path discovery failure.
	defaultBindLifetime = 30 * time.Second

	// defaultSteadyStateLifetime bounds how long a bound endpoint stays
	// active after either client was last seen.
	defaultSteadyStateLifetime = 5 * time.Minute
)

// Server implements an experimental UDP relay server.
type Server struct {
	// disco is the keypair used by the Server side of the 3-way bind
	// handshake. discoPublic caches disco.Public().
	disco       key.DiscoPrivate
	discoPublic key.DiscoPublic

	// bindLifetime and steadyStateLifetime are reported in every
	// ServerEndpoint and drive endpoint expiry in the GC loop.
	bindLifetime        time.Duration
	steadyStateLifetime time.Duration

	// addrPorts contains the ip:port pairs returned as candidate server
	// endpoints in response to an allocation request.
	addrPorts []netip.AddrPort

	// uc is the socket all relay traffic is read from and written to.
	uc *net.UDPConn

	// closeOnce, wg, and closeCh coordinate shutdown of the background
	// packet read and endpoint GC goroutines.
	closeOnce sync.Once
	wg        sync.WaitGroup
	closeCh   chan struct{}

	mu        sync.Mutex // guards the following fields
	closed    bool       // set by Close; AllocateEndpoint refuses once true
	lamportID uint64     // most recently issued LamportID
	vniPool   []uint32   // the pool of available VNIs
	byVNI     map[uint32]*serverEndpoint
	byDisco   map[pairOfDiscoPubKeys]*serverEndpoint
}

// pairOfDiscoPubKeys holds the disco public keys of the two clients served by
// a single endpoint. Always build one with newPairOfDiscoPubKeys. That puts
// the keys in lexicographical order, so a given pair of keys maps to exactly
// one value regardless of argument order.
type pairOfDiscoPubKeys [2]key.DiscoPublic

// String renders the pair as "<short key 0> <=> <short key 1>".
func (p pairOfDiscoPubKeys) String() string {
	return p[0].ShortString() + " <=> " + p[1].ShortString()
}

// newPairOfDiscoPubKeys returns discoA and discoB as a pairOfDiscoPubKeys in
// ascending order according to key.DiscoPublic.Compare. Swapping the
// arguments yields the same pair.
func newPairOfDiscoPubKeys(discoA, discoB key.DiscoPublic) pairOfDiscoPubKeys {
	pair := pairOfDiscoPubKeys{discoA, discoB}
	inOrder := slices.IsSortedFunc(pair[:], func(x, y key.DiscoPublic) int {
		return x.Compare(y)
	})
	if !inOrder {
		pair[0], pair[1] = pair[1], pair[0]
	}
	return pair
}

// ServerEndpoint contains the Server's endpoint details, as returned by
// Server.AllocateEndpoint. Clients exchange these details (e.g. over DERP) and
// use them to perform the 3-way bind handshake with the Server.
type ServerEndpoint struct {
	// ServerDisco is the Server's Disco public key used as part of the 3-way
	// bind handshake. Server will use the same ServerDisco for its lifetime.
	// ServerDisco value in combination with LamportID value represents a
	// unique ServerEndpoint allocation.
	ServerDisco key.DiscoPublic

	// LamportID identifies an allocation for the lifetime of Server. Every new
	// allocation receives a LamportID greater than any issued before it.
	// Repeated Server.AllocateEndpoint() calls that return an existing
	// allocation (one not yet bound by both clients) return that allocation's
	// LamportID again. Values are therefore monotonically non-decreasing
	// across calls.
	//
	// Clients may race to allocate on the same Server, and signal
	// ServerEndpoint details via alternative channels, e.g. DERP. Where
	// clients have local, existing state that contains ServerDisco and
	// LamportID values matching a newly learned endpoint, these can be
	// considered one and the same. If ServerDisco is equal, but LamportID is
	// unequal, LamportID comparison determines which ServerEndpoint was
	// allocated most recently.
	LamportID uint64

	// AddrPorts are the IP:Port candidate pairs the Server may be reachable
	// over: the addresses supplied to NewServer, joined with the UDP port the
	// Server actually bound.
	AddrPorts []netip.AddrPort

	// VNI (Virtual Network Identifier) is the Geneve header VNI the Server
	// will use for transmitted packets, and expects for received packets
	// associated with this endpoint. VNIs are drawn from the range
	// [1, 1<<24-1].
	VNI uint32

	// BindLifetime is the amount of time post-allocation the Server will
	// consider the endpoint active while it has yet to be bound via 3-way bind
	// handshake from both client parties.
	BindLifetime tstime.GoDuration

	// SteadyStateLifetime is the amount of time post 3-way bind handshake from
	// both client parties the Server will consider the endpoint active lacking
	// bidirectional data flow, i.e. the endpoint expires once either client
	// has gone unseen for longer than this.
	SteadyStateLifetime tstime.GoDuration
}

// serverEndpoint is the Server-internal state behind one ServerEndpoint
// allocation. Its methods are not thread-safe. The Server serializes access
// by holding Server.mu whenever it touches a serverEndpoint.
type serverEndpoint struct {
	// Identity of the allocation.
	lamportID   uint64
	vni         uint32
	allocatedAt time.Time

	// Per-client state. Every [2] array below is indexed in step with
	// discoPubKeys: index i always refers to the client whose disco public key
	// is discoPubKeys[i]. For example, discoSharedSecrets[0] is the shared
	// secret used when sealing Disco messages sent towards discoPubKeys[0].
	discoPubKeys       pairOfDiscoPubKeys
	discoSharedSecrets [2]key.DiscoShared
	handshakeState     [2]disco.BindUDPRelayHandshakeState
	addrPorts          [2]netip.AddrPort
	lastSeen           [2]time.Time // TODO(jwhited): consider using mono.Time
	challenge          [2][disco.BindUDPRelayEndpointChallengeLen]byte
}

// handleDiscoControlMsg advances the 3-way bind handshake for the client at
// senderIndex. discoMsg is an already-decrypted Disco message that arrived
// from the network address 'from'.
//
//   - BindUDPRelayEndpoint: on first contact (Init) 'from' is recorded as the
//     client's address. Retransmits are accepted only from the recorded
//     address. In both cases a Geneve control packet (Disco protocol, this
//     endpoint's VNI) carrying the Disco magic, serverDisco, and the sealed
//     BindUDPRelayEndpointChallenge is sent back to 'from'. The client then
//     moves to ChallengeSent.
//   - BindUDPRelayEndpointAnswer: accepted only in ChallengeSent, only from
//     the recorded address, and only if the answer matches the challenge.
//     The client then moves to AnswerReceived (bound), and lastSeen is set.
//
// Any other message, and anything from an already-bound client, is ignored.
func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex int, discoMsg disco.Message, uw udpWriter, serverDisco key.DiscoPublic) {
	if senderIndex < 0 || senderIndex > 1 {
		return
	}
	state := e.handshakeState[senderIndex]
	if state == disco.BindUDPRelayHandshakeStateAnswerReceived {
		// Already bound; nothing left to negotiate for this client.
		return
	}

	switch msg := discoMsg.(type) {
	case *disco.BindUDPRelayEndpoint:
		switch state {
		case disco.BindUDPRelayHandshakeStateInit:
			// The first bind from this client pins its source address.
			e.addrPorts[senderIndex] = from
		case disco.BindUDPRelayHandshakeStateChallengeSent:
			if e.addrPorts[senderIndex] != from {
				// A later bind from a different source, or a retransmit after
				// the sender's source changed; discard.
				return
			}
		default:
			// A bind is unexpected in every other handshake state.
			return
		}

		challenge := new(disco.BindUDPRelayEndpointChallenge)
		copy(challenge.Challenge[:], e.challenge[senderIndex][:])

		hdr := packet.GeneveHeader{
			Control:  true,
			VNI:      e.vni,
			Protocol: packet.GeneveProtocolDisco,
		}
		out := make([]byte, packet.GeneveFixedHeaderLength, 512)
		if err := hdr.Encode(out); err != nil {
			return
		}
		out = append(out, disco.Magic...)
		out = serverDisco.AppendTo(out)
		out = append(out, e.discoSharedSecrets[senderIndex].Seal(challenge.AppendMarshal(nil))...)
		uw.WriteMsgUDPAddrPort(out, nil, from)
		e.handshakeState[senderIndex] = disco.BindUDPRelayHandshakeStateChallengeSent

	case *disco.BindUDPRelayEndpointAnswer:
		if state != disco.BindUDPRelayHandshakeStateChallengeSent {
			// An answer only makes sense while a challenge is outstanding.
			return
		}
		if e.addrPorts[senderIndex] != from {
			// The sender's source changed since the challenge was sent.
			return
		}
		if !bytes.Equal(msg.Answer[:], e.challenge[senderIndex][:]) {
			// Wrong answer.
			return
		}
		// TODO: Consider installing a fast path via netfilter or similar to
		// relay (NAT) data packets for this serverEndpoint.
		e.handshakeState[senderIndex] = disco.BindUDPRelayHandshakeStateAnswerReceived
		// The moment of binding counts as the client's last-seen time.
		e.lastSeen[senderIndex] = time.Now()
	}
}

func (e *serverEndpoint) handleSealedDiscoControlMsg(from netip.AddrPort, b []byte, uw udpWriter, serverDisco key.DiscoPublic) {
	senderRaw, isDiscoMsg := disco.Source(b)
	if !isDiscoMsg {
		// Not a Disco message
		return
	}
	sender := key.DiscoPublicFromRaw32(mem.B(senderRaw))
	senderIndex := -1
	switch {
	case sender.Compare(e.discoPubKeys[0]) == 0:
		senderIndex = 0
	case sender.Compare(e.discoPubKeys[1]) == 0:
		senderIndex = 1
	default:
		// unknown Disco public key
		return
	}

	const headerLen = len(disco.Magic) + key.DiscoPublicRawLen
	discoPayload, ok := e.discoSharedSecrets[senderIndex].Open(b[headerLen:])
	if !ok {
		// unable to decrypt the Disco payload
		return
	}

	discoMsg, err := disco.Parse(discoPayload)
	if err != nil {
		// unable to parse the Disco payload
		return
	}

	e.handleDiscoControlMsg(from, senderIndex, discoMsg, uw, serverDisco)
}

// udpWriter is the packet-sending method packet handling depends on. The
// Server passes its *net.UDPConn here, but any implementation of this single
// method can be substituted.
type udpWriter interface {
	WriteMsgUDPAddrPort(b, oob []byte, addr netip.AddrPort) (n, oobn int, err error)
}

func (e *serverEndpoint) handlePacket(from netip.AddrPort, gh packet.GeneveHeader, b []byte, uw udpWriter, serverDisco key.DiscoPublic) {
	if !gh.Control {
		if !e.isBound() {
			// not a control packet, but serverEndpoint isn't bound
			return
		}
		var to netip.AddrPort
		switch {
		case from == e.addrPorts[0]:
			e.lastSeen[0] = time.Now()
			to = e.addrPorts[1]
		case from == e.addrPorts[1]:
			e.lastSeen[1] = time.Now()
			to = e.addrPorts[0]
		default:
			// unrecognized source
			return
		}
		// relay packet
		uw.WriteMsgUDPAddrPort(b, nil, to)
		return
	}

	if e.isBound() {
		// control packet, but serverEndpoint is already bound
		return
	}

	if gh.Protocol != packet.GeneveProtocolDisco {
		// control packet, but not Disco
		return
	}

	msg := b[packet.GeneveFixedHeaderLength:]
	e.handleSealedDiscoControlMsg(from, msg, uw, serverDisco)
}

// isExpired reports whether e should be reclaimed as of 'now'.
//   - An endpoint that is not yet bound expires once more than bindLifetime
//     has passed since allocation.
//   - A bound endpoint expires once more than steadyStateLifetime has passed
//     since either client was last seen.
func (e *serverEndpoint) isExpired(now time.Time, bindLifetime, steadyStateLifetime time.Duration) bool {
	if e.isBound() {
		return now.Sub(e.lastSeen[0]) > steadyStateLifetime ||
			now.Sub(e.lastSeen[1]) > steadyStateLifetime
	}
	return now.Sub(e.allocatedAt) > bindLifetime
}

// isBound reports whether both clients have completed the 3-way bind
// handshake, i.e. both handshake states are AnswerReceived.
func (e *serverEndpoint) isBound() bool {
	for _, st := range e.handshakeState {
		if st != disco.BindUDPRelayHandshakeStateAnswerReceived {
			return false
		}
	}
	return true
}

// NewServer constructs a Server listening on 0.0.0.0:'port'. IPv6 is not yet
// supported. Port may be 0, and what ultimately gets bound is returned as
// 'boundPort'. Supplied 'addrs' are joined with 'boundPort' and returned as
// ServerEndpoint.AddrPorts in response to Server.AllocateEndpoint() requests.
//
// NewServer returns an error if any of 'addrs' is the zero (invalid)
// netip.Addr, or if listening fails. On error no socket is left open and no
// goroutine is started.
//
// TODO: IPv6 support
// TODO: dynamic addrs:port discovery
func NewServer(port int, addrs []netip.Addr) (s *Server, boundPort int, err error) {
	// Validate addrs before opening the socket, so a bad address cannot leave a
	// bound socket behind.
	for _, addr := range addrs {
		if !addr.IsValid() {
			return nil, 0, fmt.Errorf("invalid server address: %v", addr)
		}
	}
	s = &Server{
		disco:               key.NewDisco(),
		bindLifetime:        defaultBindLifetime,
		steadyStateLifetime: defaultSteadyStateLifetime,
		closeCh:             make(chan struct{}),
		byDisco:             make(map[pairOfDiscoPubKeys]*serverEndpoint),
		byVNI:               make(map[uint32]*serverEndpoint),
	}
	s.discoPublic = s.disco.Public()
	// VNIs are 24 bits wide; 0 is never handed out.
	// TODO: instead of allocating 10s of MBs for the full pool, allocate
	// smaller chunks and increase as needed
	s.vniPool = make([]uint32, 0, 1<<24-1)
	for i := 1; i < 1<<24; i++ {
		s.vniPool = append(s.vniPool, uint32(i))
	}
	boundPort, err = s.listenOn(port)
	if err != nil {
		return nil, 0, err
	}
	// AddrPortFrom cannot fail for the already-validated addrs, so nothing
	// after listenOn can leak the socket.
	addrPorts := make([]netip.AddrPort, 0, len(addrs))
	for _, addr := range addrs {
		addrPorts = append(addrPorts, netip.AddrPortFrom(addr, uint16(boundPort)))
	}
	s.addrPorts = addrPorts
	s.wg.Add(2)
	go s.packetReadLoop()
	go s.endpointGCLoop()
	return s, boundPort, nil
}

// listenOn opens the Server's IPv4 UDP socket on 'port' (0 picks an ephemeral
// port), stores it in s.uc, and returns the port actually bound. If anything
// fails after the socket was opened, the socket is closed and s.uc is left
// unset.
func (s *Server) listenOn(port int) (int, error) {
	uc, err := net.ListenUDP("udp4", &net.UDPAddr{Port: port})
	if err != nil {
		return 0, err
	}
	// TODO: set IP_PKTINFO sockopt
	_, boundPortStr, err := net.SplitHostPort(uc.LocalAddr().String())
	if err != nil {
		// s.uc is not assigned yet; close the socket we just opened.
		uc.Close()
		return 0, err
	}
	boundPort, err := strconv.Atoi(boundPortStr)
	if err != nil {
		uc.Close()
		return 0, err
	}
	s.uc = uc
	return boundPort, nil
}

// Close closes the server. It performs these steps in order:
//   - stops new allocations (AllocateEndpoint returns ErrServerClosed),
//   - closes the socket,
//   - waits for the background goroutines to exit,
//   - releases all endpoint state.
//
// Only the first call has any effect. That includes a call made by the
// server's own packet read goroutine as it exits. Close always returns nil.
func (s *Server) Close() error {
	s.closeOnce.Do(func() {
		// Refuse new allocations immediately.
		s.mu.Lock()
		s.closed = true
		s.mu.Unlock()

		s.uc.Close()
		close(s.closeCh)
		// s.mu must NOT be held while waiting. packetReadLoop (via
		// handlePacket) and endpointGCLoop both acquire s.mu, so holding it
		// here could stop them from ever returning and deadlock Close.
		s.wg.Wait()

		s.mu.Lock()
		defer s.mu.Unlock()
		clear(s.byVNI)
		clear(s.byDisco)
		s.vniPool = nil
	})
	return nil
}

// endpointGCLoop reclaims expired endpoints until s.closeCh is closed. It wakes
// every s.bindLifetime. On each wake-up, every endpoint for which isExpired
// reports true is removed from s.byDisco and s.byVNI, and its VNI is returned
// to s.vniPool.
func (s *Server) endpointGCLoop() {
	defer s.wg.Done()

	ticker := time.NewTicker(s.bindLifetime)
	defer ticker.Stop()

	for {
		select {
		case <-s.closeCh:
			return
		case <-ticker.C:
		}

		now := time.Now()
		// TODO: consider performance implications of scanning all endpoints and
		// holding s.mu for the duration. Keep it simple (and slow) for now.
		s.mu.Lock()
		for pair, e := range s.byDisco {
			if !e.isExpired(now, s.bindLifetime, s.steadyStateLifetime) {
				continue
			}
			delete(s.byDisco, pair)
			delete(s.byVNI, e.vni)
			s.vniPool = append(s.vniPool, e.vni)
		}
		s.mu.Unlock()
	}
}

func (s *Server) handlePacket(from netip.AddrPort, b []byte, uw udpWriter) {
	gh := packet.GeneveHeader{}
	err := gh.Decode(b)
	if err != nil {
		return
	}
	// TODO: consider performance implications of holding s.mu for the remainder
	// of this method, which does a bunch of disco/crypto work depending. Keep
	// it simple (and slow) for now.
	s.mu.Lock()
	defer s.mu.Unlock()
	e, ok := s.byVNI[gh.VNI]
	if !ok {
		// unknown VNI
		return
	}

	e.handlePacket(from, gh, b, uw, s.discoPublic)
}

// packetReadLoop reads packets from s.uc and passes each one to handlePacket
// until a read fails (for example after Close has closed the socket). On exit
// it marks itself done on s.wg and then closes the Server.
func (s *Server) packetReadLoop() {
	defer func() {
		// Done must precede Close, because Close waits on s.wg.
		s.wg.Done()
		s.Close()
	}()
	buf := make([]byte, 1<<16-1)
	for {
		// TODO: extract laddr from IP_PKTINFO for use in reply
		n, src, err := s.uc.ReadFromUDPAddrPort(buf)
		if err != nil {
			break
		}
		s.handlePacket(src, buf[:n], s.uc)
	}
}

// ErrServerClosed is returned by Server.AllocateEndpoint once the Server has
// been closed.
var ErrServerClosed = errors.New("server closed")

// AllocateEndpoint allocates a ServerEndpoint for the provided pair of
// key.DiscoPublic's. The order of discoA and discoB does not matter.
//
// If an endpoint already exists for the pair, it is handled as follows:
//   - Not yet bound by both clients: its details (same VNI and LamportID) are
//     returned again instead of allocating a new one.
//   - Already bound: it is discarded, and a new endpoint is allocated.
//
// It returns ErrServerClosed if the server has been closed. It returns a
// non-nil error if:
//   - either key equals the Server's own disco key,
//   - discoA equals discoB, or
//   - the VNI pool is exhausted.
//
// The returned ServerEndpoint.AddrPorts is a fresh copy owned by the caller.
func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (ServerEndpoint, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.closed {
		return ServerEndpoint{}, ErrServerClosed
	}

	if discoA.Compare(s.discoPublic) == 0 || discoB.Compare(s.discoPublic) == 0 {
		return ServerEndpoint{}, fmt.Errorf("client disco equals server disco: %s", s.discoPublic.ShortString())
	}
	if discoA.Compare(discoB) == 0 {
		// Both handshake slots would hold the same key. Incoming messages are
		// only ever matched to index 0, so such an endpoint could never bind.
		return ServerEndpoint{}, fmt.Errorf("client discos are equal: %s", discoA.ShortString())
	}

	// details builds the caller-facing view of e. AddrPorts is cloned so that
	// callers cannot mutate the Server's shared s.addrPorts slice.
	details := func(e *serverEndpoint) ServerEndpoint {
		return ServerEndpoint{
			ServerDisco:         s.discoPublic,
			AddrPorts:           slices.Clone(s.addrPorts),
			VNI:                 e.vni,
			LamportID:           e.lamportID,
			BindLifetime:        tstime.GoDuration{Duration: s.bindLifetime},
			SteadyStateLifetime: tstime.GoDuration{Duration: s.steadyStateLifetime},
		}
	}

	pair := newPairOfDiscoPubKeys(discoA, discoB)
	e, ok := s.byDisco[pair]
	if ok {
		if !e.isBound() {
			// If the endpoint is not yet bound this is likely an allocation
			// race between two clients on the same Server. Instead of
			// re-allocating we return the existing allocation. We do not reset
			// e.allocatedAt in case a client is "stuck" in an allocation
			// loop and will not be able to complete a handshake, for whatever
			// reason. Once the endpoint expires a new endpoint will be
			// allocated. Clients can resolve duplicate ServerEndpoint details
			// via ServerEndpoint.LamportID.
			//
			// TODO: consider ServerEndpoint.BindLifetime -= time.Now()-e.allocatedAt
			// to give the client a more accurate picture of the bind window.
			// Or, some threshold to trigger re-allocation if too much time has
			// already passed since it was originally allocated.
			return details(e), nil
		}
		// If an endpoint exists for the pair of key.DiscoPublic's, and is
		// already bound, delete it. We will re-allocate a new endpoint. Chances
		// are clients cannot make use of the existing, bound allocation if
		// they are requesting a new one.
		delete(s.byDisco, pair)
		delete(s.byVNI, e.vni)
		s.vniPool = append(s.vniPool, e.vni)
	}

	if len(s.vniPool) == 0 {
		return ServerEndpoint{}, errors.New("VNI pool exhausted")
	}

	s.lamportID++
	e = &serverEndpoint{
		discoPubKeys: pair,
		lamportID:    s.lamportID,
		allocatedAt:  time.Now(),
	}
	e.discoSharedSecrets[0] = s.disco.Shared(e.discoPubKeys[0])
	e.discoSharedSecrets[1] = s.disco.Shared(e.discoPubKeys[1])
	e.vni, s.vniPool = s.vniPool[0], s.vniPool[1:]
	// Per-client challenges come from crypto/rand so they are unpredictable.
	rand.Read(e.challenge[0][:])
	rand.Read(e.challenge[1][:])

	s.byDisco[pair] = e
	s.byVNI[e.vni] = e

	return details(e), nil
}