// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package testcontrol contains a minimal control plane server for testing purposes.
package testcontrol

import (
	"bytes"
	"context"
	crand "crypto/rand"
	"encoding/binary"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"math/rand"
	"net/http"
	"net/http/httptest"
	"net/url"
	"sort"
	"strings"
	"sync"
	"time"

	"github.com/klauspost/compress/zstd"
	"go4.org/mem"
	"inet.af/netaddr"
	"tailscale.com/net/tsaddr"
	"tailscale.com/smallzstd"
	"tailscale.com/tailcfg"
	"tailscale.com/types/key"
	"tailscale.com/types/logger"
)

const msgLimit = 1 << 20 // encrypted message length limit

// Server is a control plane server. Its zero value is ready for use.
// Everything is stored in memory in one tailnet.
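//
// A minimal usage sketch (illustrative only; any wiring that sets
// ExplicitBaseURL or HTTPTestServer works):
//
//	control := &testcontrol.Server{}
//	ts := httptest.NewUnstartedServer(control)
//	control.HTTPTestServer = ts
//	ts.Start()
//	// point clients at ts.URL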
type Server struct {
	Logf        logger.Logf      // nil means to use the log package
	DERPMap     *tailcfg.DERPMap // nil means to use prod DERP map
	RequireAuth bool
	Verbose     bool
	DNSConfig   *tailcfg.DNSConfig // nil means no DNS config

	// ExplicitBaseURL or HTTPTestServer must be set.
	ExplicitBaseURL string           // e.g. "http://127.0.0.1:1234" with no trailing slash
	HTTPTestServer  *httptest.Server // if non-nil, used to get BaseURL

	initMuxOnce sync.Once
	mux         *http.ServeMux

	mu            sync.Mutex
	inServeMap    int
	cond          *sync.Cond // lazily initialized by condLocked
	pubKey        key.MachinePublic
	privKey       key.ControlPrivate // not strictly needed vs. MachinePrivate, but handy to test type interactions.
	nodes         map[key.NodePublic]*tailcfg.Node
	users         map[key.NodePublic]*tailcfg.User
	logins        map[key.NodePublic]*tailcfg.Login
	updates       map[tailcfg.NodeID]chan updateType
	authPath      map[string]*AuthPath
	nodeKeyAuthed map[key.NodePublic]bool // key => true once authenticated
	pingReqsToAdd map[key.NodePublic]*tailcfg.PingRequest
	allExpired    bool // All nodes will be told their node key is expired.
}

// BaseURL returns the server's base URL, without trailing slash.
func (s *Server) BaseURL() string {
	if e := s.ExplicitBaseURL; e != "" {
		return e
	}
	if hs := s.HTTPTestServer; hs != nil {
		if hs.URL != "" {
			return hs.URL
		}
		panic("Server.HTTPTestServer not started")
	}
	panic("Server ExplicitBaseURL and HTTPTestServer both unset")
}

// NumNodes returns the number of nodes in the testcontrol server.
//
// This is useful when connecting a bunch of virtual machines to a testcontrol
// server to see how many of them connected successfully.
func (s *Server) NumNodes() int {
	s.mu.Lock()
	defer s.mu.Unlock()

	return len(s.nodes)
}

// condLocked lazily initializes and returns s.cond.
// s.mu must be held.
func (s *Server) condLocked() *sync.Cond {
	if s.cond == nil {
		s.cond = sync.NewCond(&s.mu)
	}
	return s.cond
}

// AwaitNodeInMapRequest waits for node k to be stuck in a map poll.
// It returns an error if and only if the context is done first.
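//
// A test sketch (ctx, nodeKey, and t come from the surrounding test):
//
//	if err := control.AwaitNodeInMapRequest(ctx, nodeKey); err != nil {
//		t.Fatalf("node never entered a map poll: %v", err)
//	}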
func (s *Server) AwaitNodeInMapRequest(ctx context.Context, k key.NodePublic) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	cond := s.condLocked()

	done := make(chan struct{})
	defer close(done)
	go func() {
		select {
		case <-done:
		case <-ctx.Done():
			cond.Broadcast()
		}
	}()

	for {
		node := s.nodeLocked(k)
		if node == nil {
			return errors.New("unknown node key")
		}
		if _, ok := s.updates[node.ID]; ok {
			return nil
		}
		cond.Wait()
		if err := ctx.Err(); err != nil {
			return err
		}
	}
}

// AddPingRequest sends the ping pr to nodeKeyDst. It reports whether it
// did so; that is, whether nodeKeyDst was connected.
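//
// A sketch, where probeURL is an illustrative HTTP endpoint the node
// can reach:
//
//	pr := &tailcfg.PingRequest{URL: probeURL, Log: true}
//	if !control.AddPingRequest(nodeKey, pr) {
//		// nodeKey wasn't connected
//	}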
func (s *Server) AddPingRequest(nodeKeyDst key.NodePublic, pr *tailcfg.PingRequest) bool {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.pingReqsToAdd == nil {
		s.pingReqsToAdd = map[key.NodePublic]*tailcfg.PingRequest{}
	}
	// Look up the destination node; if it's connected, we'll queue the
	// ping and wake up its map poll below.
	node := s.nodeLocked(nodeKeyDst)
	if node == nil {
		return false
	}

	s.pingReqsToAdd[nodeKeyDst] = pr
	nodeID := node.ID
	oldUpdatesCh := s.updates[nodeID]
	return sendUpdate(oldUpdatesCh, updateDebugInjection)
}

// SetExpireAllNodes sets whether the server should tell every node that
// its node key is expired, and then notifies all connected nodes.
func (s *Server) SetExpireAllNodes(expired bool) {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.allExpired = expired

	for _, node := range s.nodes {
		sendUpdate(s.updates[node.ID], updateSelfChanged)
	}
}

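// AuthPath is the server-side state for one interactive auth URL
// (the "/auth/<hex>" path handed out via RegisterResponse.AuthURL).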
type AuthPath struct {
	nodeKey key.NodePublic

	closeOnce sync.Once
	ch        chan struct{}
	success   bool
}

func (ap *AuthPath) completeSuccessfully() {
	ap.success = true
	close(ap.ch)
}

// CompleteSuccessfully completes the login path successfully, as if
// the user did the whole auth dance.
func (ap *AuthPath) CompleteSuccessfully() {
	ap.closeOnce.Do(ap.completeSuccessfully)
}

func (s *Server) logf(format string, a ...interface{}) {
	if s.Logf != nil {
		s.Logf(format, a...)
	} else {
		log.Printf(format, a...)
	}
}

func (s *Server) initMux() {
	s.mux = http.NewServeMux()
	s.mux.HandleFunc("/", s.serveUnhandled)
	s.mux.HandleFunc("/key", s.serveKey)
	s.mux.HandleFunc("/machine/", s.serveMachine)
}

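// ServeHTTP implements http.Handler.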
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	s.initMuxOnce.Do(s.initMux)
	s.mux.ServeHTTP(w, r)
}

func (s *Server) serveUnhandled(w http.ResponseWriter, r *http.Request) {
	var got bytes.Buffer
	r.Write(&got)
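	// Panic in a new goroutine to bypass the net/http server's panic
	// recovery, so the test process dies loudly. (The same pattern is
	// used in the other handlers in this file.)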
	go panic(fmt.Sprintf("testcontrol.Server received unhandled request: %s", got.Bytes()))
}

func (s *Server) publicKey() key.MachinePublic {
	pub, _ := s.keyPair()
	return pub
}

func (s *Server) privateKey() key.ControlPrivate {
	_, priv := s.keyPair()
	return priv
}

func (s *Server) keyPair() (pub key.MachinePublic, priv key.ControlPrivate) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.pubKey.IsZero() {
		s.privKey = key.NewControl()
		s.pubKey = s.privKey.Public()
	}
	return s.pubKey, s.privKey
}

func (s *Server) serveKey(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "text/plain")
	w.WriteHeader(200)
	io.WriteString(w, s.publicKey().UntypedHexString())
}

func (s *Server) serveMachine(w http.ResponseWriter, r *http.Request) {
	mkeyStr := strings.TrimPrefix(r.URL.Path, "/machine/")
	rem := ""
	if i := strings.IndexByte(mkeyStr, '/'); i != -1 {
		rem = mkeyStr[i:]
		mkeyStr = mkeyStr[:i]
	}

	mkey, err := key.ParseMachinePublicUntyped(mem.S(mkeyStr))
	if err != nil {
		http.Error(w, "bad machine key hex", 400)
		return
	}

	if r.Method != "POST" {
		http.Error(w, "POST required", 400)
		return
	}

	switch rem {
	case "":
		s.serveRegister(w, r, mkey)
	case "/map":
		s.serveMap(w, r, mkey)
	default:
		s.serveUnhandled(w, r)
	}
}

// Node returns the node for nodeKey. It's always nil or cloned memory.
func (s *Server) Node(nodeKey key.NodePublic) *tailcfg.Node {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.nodeLocked(nodeKey)
}

// nodeLocked returns the node for nodeKey. It's always nil or cloned memory.
//
// s.mu must be held.
func (s *Server) nodeLocked(nodeKey key.NodePublic) *tailcfg.Node {
	return s.nodes[nodeKey].Clone()
}

// AddFakeNode injects a fake node into the server.
func (s *Server) AddFakeNode() {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.nodes == nil {
		s.nodes = make(map[key.NodePublic]*tailcfg.Node)
	}
	nk := key.NewNode().Public()
	mk := key.NewMachine().Public()
	dk := key.NewDisco().Public()
	r := nk.Raw32()
	id := int64(binary.LittleEndian.Uint64(r[:]))
	ip := netaddr.IPv4(r[0], r[1], r[2], r[3])
	addr := netaddr.IPPrefixFrom(ip, 32)
	s.nodes[nk] = &tailcfg.Node{
		ID:                tailcfg.NodeID(id),
		StableID:          tailcfg.StableNodeID(fmt.Sprintf("TESTCTRL%08x", id)),
		User:              tailcfg.UserID(id),
		Machine:           mk,
		Key:               nk,
		MachineAuthorized: true,
		DiscoKey:          dk,
		Addresses:         []netaddr.IPPrefix{addr},
		AllowedIPs:        []netaddr.IPPrefix{addr},
	}
	// TODO: send updates to other (non-fake?) nodes
}

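// AllNodes returns a cloned copy of all nodes, sorted by StableID.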
func (s *Server) AllNodes() (nodes []*tailcfg.Node) {
	s.mu.Lock()
	defer s.mu.Unlock()
	for _, n := range s.nodes {
		nodes = append(nodes, n.Clone())
	}
	sort.Slice(nodes, func(i, j int) bool {
		return nodes[i].StableID < nodes[j].StableID
	})
	return nodes
}

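// getUser returns the user and login for nodeKey, creating both on
// first use.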
func (s *Server) getUser(nodeKey key.NodePublic) (*tailcfg.User, *tailcfg.Login) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.users == nil {
		s.users = map[key.NodePublic]*tailcfg.User{}
	}
	if s.logins == nil {
		s.logins = map[key.NodePublic]*tailcfg.Login{}
	}
	if u, ok := s.users[nodeKey]; ok {
		return u, s.logins[nodeKey]
	}
	id := tailcfg.UserID(len(s.users) + 1)
	domain := "fake-control.example.net"
	loginName := fmt.Sprintf("user-%d@%s", id, domain)
	displayName := fmt.Sprintf("User %d", id)
	login := &tailcfg.Login{
		ID:            tailcfg.LoginID(id),
		Provider:      "testcontrol",
		LoginName:     loginName,
		DisplayName:   displayName,
		ProfilePicURL: "https://tailscale.com/static/images/marketing/team-carney.jpg",
		Domain:        domain,
	}
	user := &tailcfg.User{
		ID:          id,
		LoginName:   loginName,
		DisplayName: displayName,
		Domain:      domain,
		Logins:      []tailcfg.LoginID{login.ID},
	}
	s.users[nodeKey] = user
	s.logins[nodeKey] = login
	return user, login
}

// authPathDone returns a channel that's closed once the authPath
// ("/auth/XXXXXX") has been authenticated, or nil if authPath is
// unknown.
func (s *Server) authPathDone(authPath string) <-chan struct{} {
	s.mu.Lock()
	defer s.mu.Unlock()
	if a, ok := s.authPath[authPath]; ok {
		return a.ch
	}
	return nil
}

func (s *Server) addAuthPath(authPath string, nodeKey key.NodePublic) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.authPath == nil {
		s.authPath = map[string]*AuthPath{}
	}
	s.authPath[authPath] = &AuthPath{
		ch:      make(chan struct{}),
		nodeKey: nodeKey,
	}
}

// CompleteAuth marks the provided path or URL (containing
// "/auth/...")  as successfully authenticated, unblocking any
// requests blocked on that in serveRegister.
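//
// A typical interactive-auth test flow (sketch; authURL is the
// RegisterResponse.AuthURL received by the client under test):
//
//	control := &testcontrol.Server{RequireAuth: true}
//	// ... client registers and blocks in its followup request ...
//	if !control.CompleteAuth(authURL) {
//		t.Fatal("unknown auth path")
//	}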
func (s *Server) CompleteAuth(authPathOrURL string) bool {
	i := strings.Index(authPathOrURL, "/auth/")
	if i == -1 {
		return false
	}
	authPath := authPathOrURL[i:]

	s.mu.Lock()
	defer s.mu.Unlock()
	ap, ok := s.authPath[authPath]
	if !ok {
		return false
	}
	if ap.nodeKey.IsZero() {
		panic("zero AuthPath.NodeKey")
	}
	if s.nodeKeyAuthed == nil {
		s.nodeKeyAuthed = map[key.NodePublic]bool{}
	}
	s.nodeKeyAuthed[ap.nodeKey] = true
	ap.CompleteSuccessfully()
	return true
}

func (s *Server) serveRegister(w http.ResponseWriter, r *http.Request, mkey key.MachinePublic) {
	msg, err := io.ReadAll(io.LimitReader(r.Body, msgLimit))
	if err != nil {
		r.Body.Close()
		http.Error(w, fmt.Sprintf("bad map request read: %v", err), 400)
		return
	}
	r.Body.Close()

	var req tailcfg.RegisterRequest
	if err := s.decode(mkey, msg, &req); err != nil {
		go panic(fmt.Sprintf("serveRegister: decode: %v", err))
	}
	if req.Version != 1 {
		go panic(fmt.Sprintf("serveRegister: unsupported version: %d", req.Version))
	}
	if req.NodeKey.IsZero() {
		go panic("serveRegister: request has zero node key")
	}
	if s.Verbose {
		j, _ := json.MarshalIndent(req, "", "\t")
		log.Printf("Got %T: %s", req, j)
	}

	// If this is a followup request, wait until the interactive visit to the followup URL has completed.
	if req.Followup != "" {
		followupURL, err := url.Parse(req.Followup)
		if err != nil {
			panic(err)
		}
		doneCh := s.authPathDone(followupURL.Path)
		select {
		case <-r.Context().Done():
			return
		case <-doneCh:
		}
		// TODO(bradfitz): support a side test API to mark an
		// auth as failed so we can send an error response in
		// some follow-ups? For now all are successes.
	}

	nk := req.NodeKey

	user, login := s.getUser(nk)
	s.mu.Lock()
	if s.nodes == nil {
		s.nodes = map[key.NodePublic]*tailcfg.Node{}
	}

	machineAuthorized := true // TODO: add Server.RequireMachineAuth

	v4Prefix := netaddr.IPPrefixFrom(netaddr.IPv4(100, 64, uint8(tailcfg.NodeID(user.ID)>>8), uint8(tailcfg.NodeID(user.ID))), 32)
	v6Prefix := netaddr.IPPrefixFrom(tsaddr.Tailscale4To6(v4Prefix.IP()), 128)

	allowedIPs := []netaddr.IPPrefix{
		v4Prefix,
		v6Prefix,
	}

	s.nodes[nk] = &tailcfg.Node{
		ID:                tailcfg.NodeID(user.ID),
		StableID:          tailcfg.StableNodeID(fmt.Sprintf("TESTCTRL%08x", int(user.ID))),
		User:              user.ID,
		Machine:           mkey,
		Key:               req.NodeKey,
		MachineAuthorized: machineAuthorized,
		Addresses:         allowedIPs,
		AllowedIPs:        allowedIPs,
		Hostinfo:          req.Hostinfo.View(),
	}
	requireAuth := s.RequireAuth
	if requireAuth && s.nodeKeyAuthed[nk] {
		requireAuth = false
	}
	allExpired := s.allExpired
	s.mu.Unlock()

	authURL := ""
	if requireAuth {
		randHex := make([]byte, 10)
		crand.Read(randHex)
		authPath := fmt.Sprintf("/auth/%x", randHex)
		s.addAuthPath(authPath, nk)
		authURL = s.BaseURL() + authPath
	}

	res, err := s.encode(mkey, false, tailcfg.RegisterResponse{
		User:              *user,
		Login:             *login,
		NodeKeyExpired:    allExpired,
		MachineAuthorized: machineAuthorized,
		AuthURL:           authURL,
	})
	if err != nil {
		go panic(fmt.Sprintf("serveRegister: encode: %v", err))
	}
	w.WriteHeader(200)
	w.Write(res)
}

// updateType indicates why a long-polling map request is being woken
// up for an update.
type updateType int

const (
	// updatePeerChanged is an update that a peer has changed.
	updatePeerChanged updateType = iota + 1

	// updateSelfChanged is an update that the node changed itself
	// via a lite endpoint update. These ones are never dup-suppressed,
	// as the client is expecting an answer regardless.
	updateSelfChanged

	// updateDebugInjection is an update used to inject a pending PingRequest.
	updateDebugInjection
)

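// updateLocked sends an updatePeerChanged update to each node in peers.
//
// s.mu must be held.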
func (s *Server) updateLocked(source string, peers []tailcfg.NodeID) {
	for _, peer := range peers {
		sendUpdate(s.updates[peer], updatePeerChanged)
	}
}

// sendUpdate sends updateType to dst if dst is non-nil and
// has capacity. It reports whether a value was sent.
func sendUpdate(dst chan<- updateType, updateType updateType) bool {
	if dst == nil {
		return false
	}
	// The dst channel has a buffer size of 1.
	// If we fail to insert an update into the buffer, that means
	// an update is already pending.
	select {
	case dst <- updateType:
		return true
	default:
		return false
	}
}

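// UpdateNode stores a clone of n and returns the IDs of all other
// nodes, which should then be sent an update.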
func (s *Server) UpdateNode(n *tailcfg.Node) (peersToUpdate []tailcfg.NodeID) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if n.Key.IsZero() {
		panic("zero nodekey")
	}
	s.nodes[n.Key] = n.Clone()
	for _, n2 := range s.nodes {
		if n.ID != n2.ID {
			peersToUpdate = append(peersToUpdate, n2.ID)
		}
	}
	return peersToUpdate
}

func (s *Server) incrInServeMap(delta int) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.inServeMap += delta
}

// InServeMap returns the number of clients currently in a MapRequest HTTP handler.
func (s *Server) InServeMap() int {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.inServeMap
}

func (s *Server) serveMap(w http.ResponseWriter, r *http.Request, mkey key.MachinePublic) {
	s.incrInServeMap(1)
	defer s.incrInServeMap(-1)
	ctx := r.Context()

	msg, err := io.ReadAll(io.LimitReader(r.Body, msgLimit))
	if err != nil {
		r.Body.Close()
		http.Error(w, fmt.Sprintf("bad map request read: %v", err), 400)
		return
	}
	r.Body.Close()

	req := new(tailcfg.MapRequest)
	if err := s.decode(mkey, msg, req); err != nil {
		go panic(fmt.Sprintf("bad map request: %v", err))
	}

	jitter := time.Duration(rand.Intn(8000)) * time.Millisecond
	keepAlive := 50*time.Second + jitter

	node := s.Node(req.NodeKey)
	if node == nil {
		http.Error(w, "node not found", 400)
		return
	}
	if node.Machine != mkey {
		http.Error(w, "node doesn't match machine key", 400)
		return
	}

	var peersToUpdate []tailcfg.NodeID
	if !req.ReadOnly {
		endpoints := filterInvalidIPv6Endpoints(req.Endpoints)
		node.Endpoints = endpoints
		node.DiscoKey = req.DiscoKey
		if req.Hostinfo != nil {
			node.Hostinfo = req.Hostinfo.View()
			if ni := node.Hostinfo.NetInfo(); ni.Valid() {
				if ni.PreferredDERP() != 0 {
					node.DERP = fmt.Sprintf("127.3.3.40:%d", ni.PreferredDERP())
				}
			}
		}
		peersToUpdate = s.UpdateNode(node)
	}

	nodeID := node.ID

	s.mu.Lock()
	updatesCh := make(chan updateType, 1)
	oldUpdatesCh := s.updates[nodeID]
	if breakSameNodeMapResponseStreams(req) {
		if oldUpdatesCh != nil {
			close(oldUpdatesCh)
		}
		if s.updates == nil {
			s.updates = map[tailcfg.NodeID]chan updateType{}
		}
		s.updates[nodeID] = updatesCh
	} else {
		sendUpdate(oldUpdatesCh, updateSelfChanged)
	}
	s.updateLocked("serveMap", peersToUpdate)
	s.condLocked().Broadcast()
	s.mu.Unlock()

	// ReadOnly implies no streaming, as it doesn't
	// register an updatesCh to get updates.
	streaming := req.Stream && !req.ReadOnly
	compress := req.Compress != ""

	w.WriteHeader(200)
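	// Send one map response, then (if streaming) loop: wait for an
	// update or a keep-alive tick before sending the next response.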
	for {
		res, err := s.MapResponse(req)
		if err != nil {
			// TODO: log
			return
		}
		if res == nil {
			return // done
		}

		s.mu.Lock()
		allExpired := s.allExpired
		s.mu.Unlock()
		if allExpired {
			res.Node.KeyExpiry = time.Now().Add(-1 * time.Minute)
		}
		// TODO: add minner if/when needed
		resBytes, err := json.Marshal(res)
		if err != nil {
			s.logf("json.Marshal: %v", err)
			return
		}
		if err := s.sendMapMsg(w, mkey, compress, resBytes); err != nil {
			return
		}
		if !streaming {
			return
		}
	keepAliveLoop:
		for {
			var keepAliveTimer *time.Timer
			var keepAliveTimerCh <-chan time.Time
			if keepAlive > 0 {
				keepAliveTimer = time.NewTimer(keepAlive)
				keepAliveTimerCh = keepAliveTimer.C
			}
			select {
			case <-ctx.Done():
				if keepAliveTimer != nil {
					keepAliveTimer.Stop()
				}
				return
			case _, ok := <-updatesCh:
				if !ok {
					// replaced by new poll request
					return
				}
				break keepAliveLoop
			case <-keepAliveTimerCh:
				if err := s.sendMapMsg(w, mkey, compress, keepAliveMsg); err != nil {
					return
				}
			}
		}
	}
}

var keepAliveMsg = &struct {
	KeepAlive bool
}{
	KeepAlive: true,
}

// MapResponse generates a MapResponse for a MapRequest.
//
// It doesn't modify the set of nodes, but it does consume any
// PingRequest pending for the requesting node.
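//
// A direct-call sketch for tests (nk is a registered node's key):
//
//	res, err := control.MapResponse(&tailcfg.MapRequest{NodeKey: nk})
//	if err != nil || res == nil {
//		// error, or the node is unknown
//	}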
func (s *Server) MapResponse(req *tailcfg.MapRequest) (res *tailcfg.MapResponse, err error) {
	nk := req.NodeKey
	node := s.Node(nk)
	if node == nil {
		// node key rotated away (once test server supports that)
		return nil, nil
	}
	user, _ := s.getUser(nk)
	t := time.Date(2020, 8, 3, 0, 0, 0, 1, time.UTC)
	res = &tailcfg.MapResponse{
		Node:            node,
		DERPMap:         s.DERPMap,
		Domain:          string(user.Domain),
		CollectServices: "true",
		PacketFilter:    tailcfg.FilterAllowAll,
		Debug: &tailcfg.Debug{
			DisableUPnP: "true",
		},
		DNSConfig:   s.DNSConfig,
		ControlTime: &t,
	}
	for _, p := range s.AllNodes() {
		if p.StableID != node.StableID {
			res.Peers = append(res.Peers, p)
		}
	}
	sort.Slice(res.Peers, func(i, j int) bool {
		return res.Peers[i].ID < res.Peers[j].ID
	})

	v4Prefix := netaddr.IPPrefixFrom(netaddr.IPv4(100, 64, uint8(tailcfg.NodeID(user.ID)>>8), uint8(tailcfg.NodeID(user.ID))), 32)
	v6Prefix := netaddr.IPPrefixFrom(tsaddr.Tailscale4To6(v4Prefix.IP()), 128)

	res.Node.Addresses = []netaddr.IPPrefix{
		v4Prefix,
		v6Prefix,
	}
	res.Node.AllowedIPs = res.Node.Addresses

	// Consume the PingRequest while protected by mutex if it exists
	s.mu.Lock()
	if pr, ok := s.pingReqsToAdd[nk]; ok {
		res.PingRequest = pr
		delete(s.pingReqsToAdd, nk)
	}
	s.mu.Unlock()
	return res, nil
}

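// sendMapMsg writes msg to w as one frame of the map response stream:
// a 4-byte little-endian payload length, then the payload (JSON,
// optionally zstd-compressed) sealed to mkey.
//
// A client-side read sketch (body is the streaming response body):
//
//	var siz [4]byte
//	io.ReadFull(body, siz[:])
//	sealed := make([]byte, binary.LittleEndian.Uint32(siz[:]))
//	io.ReadFull(body, sealed)
//	// sealed is then opened with the machine key and JSON-decoded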
func (s *Server) sendMapMsg(w http.ResponseWriter, mkey key.MachinePublic, compress bool, msg interface{}) error {
	resBytes, err := s.encode(mkey, compress, msg)
	if err != nil {
		return err
	}
	if len(resBytes) > 16<<20 {
		return fmt.Errorf("map message too big: %d", len(resBytes))
	}
	var siz [4]byte
	binary.LittleEndian.PutUint32(siz[:], uint32(len(resBytes)))
	if _, err := w.Write(siz[:]); err != nil {
		return err
	}
	if _, err := w.Write(resBytes); err != nil {
		return err
	}
	if f, ok := w.(http.Flusher); ok {
		f.Flush()
	} else {
		s.logf("[unexpected] ResponseWriter %T is not a Flusher", w)
	}
	return nil
}

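// decode opens the sealed message msg from mkey and JSON-unmarshals
// the plaintext into v.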
func (s *Server) decode(mkey key.MachinePublic, msg []byte, v interface{}) error {
	if len(msg) == msgLimit {
		return errors.New("encrypted message too long")
	}

	decrypted, ok := s.privateKey().OpenFrom(mkey, msg)
	if !ok {
		return errors.New("can't decrypt request")
	}
	return json.Unmarshal(decrypted, v)
}

var zstdEncoderPool = &sync.Pool{
	New: func() interface{} {
		encoder, err := smallzstd.NewEncoder(nil, zstd.WithEncoderLevel(zstd.SpeedFastest))
		if err != nil {
			panic(err)
		}
		return encoder
	},
}

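// encode JSON-encodes v (unless it is already a []byte), optionally
// zstd-compresses it, and seals the result to mkey.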
func (s *Server) encode(mkey key.MachinePublic, compress bool, v interface{}) (b []byte, err error) {
	var isBytes bool
	if b, isBytes = v.([]byte); !isBytes {
		b, err = json.Marshal(v)
		if err != nil {
			return nil, err
		}
	}
	if compress {
		encoder := zstdEncoderPool.Get().(*zstd.Encoder)
		b = encoder.EncodeAll(b, nil)
		encoder.Close()
		zstdEncoderPool.Put(encoder)
	}
	return s.privateKey().SealTo(mkey, b), nil
}

// filterInvalidIPv6Endpoints removes invalid IPv6 endpoints from eps,
// modifying the slice in place and returning the potentially smaller
// subset (aliasing the original memory).
//
// Two types of IPv6 endpoints are considered invalid: link-local
// addresses, and anything with a zone.
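//
// For example (sketch), given ["1.2.3.4:5", "[fe80::1%eth0]:5"] it
// returns ["1.2.3.4:5"].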
func filterInvalidIPv6Endpoints(eps []string) []string {
	clean := eps[:0]
	for _, ep := range eps {
		if keepClientEndpoint(ep) {
			clean = append(clean, ep)
		}
	}
	return clean
}

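// keepClientEndpoint reports whether the client-advertised endpoint ep
// should be kept.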
func keepClientEndpoint(ep string) bool {
	ipp, err := netaddr.ParseIPPort(ep)
	if err != nil {
		// Shouldn't have made it this far if the incoming
		// JSON request unmarshalled successfully.
		return false
	}
	ip := ipp.IP()
	if ip.Zone() != "" {
		return false
	}
	if ip.Is6() && ip.IsLinkLocalUnicast() {
		// We let clients send these for now, but
		// tailscaled doesn't know how to use them yet
		// so we filter them out for now. A future
		// MapRequest.Version might signal that
		// clients know how to use them (e.g. try all
		// local scopes).
		return false
	}
	return true
}

// breakSameNodeMapResponseStreams reports whether req should break a
// prior long-polling MapResponse stream (if active) from the same
// node ID.
func breakSameNodeMapResponseStreams(req *tailcfg.MapRequest) bool {
	if req.ReadOnly {
		// Don't register our updatesCh for closability
		// nor close another peer's if we're a read-only request.
		return false
	}
	if !req.Stream && req.OmitPeers {
		// Likewise, if we're not streaming and not asking for peers
		// (but the request is still mutable, without ReadOnly set),
		// consider this an endpoint update request only, and don't
		// close any existing map response for this node ID. It's
		// likely the same client with a built-up compression context.
		// We want to let it update its new endpoints with us without
		// breaking that other long-running map response.
		return false
	}
	return true
}