// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package controlclient

//go:generate go run tailscale.com/cmd/cloner -type=Persist -output=direct_clone.go

import (
	"bytes"
	"context"
	"crypto/rand"
	"encoding/binary"
	"encoding/json"
	"errors"
	"flag"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"net/http"
	"net/url"
	"os"
	"os/exec"
	"reflect"
	"runtime"
	"sort"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/tailscale/wireguard-go/wgcfg"
	"golang.org/x/crypto/nacl/box"
	"golang.org/x/oauth2"
	"inet.af/netaddr"
	"tailscale.com/log/logheap"
	"tailscale.com/net/dnscache"
	"tailscale.com/net/netns"
	"tailscale.com/net/tlsdial"
	"tailscale.com/net/tshttpproxy"
	"tailscale.com/tailcfg"
	"tailscale.com/types/logger"
	"tailscale.com/types/opt"
	"tailscale.com/types/structs"
	"tailscale.com/version"
)

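// Persist is the client state persisted to disk across restarts: the
// node's private keys and the login identity they were registered under.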
type Persist struct {
	_ structs.Incomparable

	// LegacyFrontendPrivateMachineKey is here temporarily
	// (starting 2020-09-28) during migration of Windows users'
	// machine keys from frontend storage to the backend. On the
	// first LocalBackend.Start call, the backend will initialize
	// the real (backend-owned) machine key from the frontend's
	// provided value (if non-zero), picking a new random one if
	// needed. This field should be considered read-only from GUI
	// frontends. The real value should not be written back in
	// this field, lest the frontend persist it to disk.
	LegacyFrontendPrivateMachineKey wgcfg.PrivateKey `json:"PrivateMachineKey"`

	PrivateNodeKey    wgcfg.PrivateKey
	OldPrivateNodeKey wgcfg.PrivateKey // needed to request key rotation
	Provider          string
	LoginName         string
}

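// Equals reports whether p and p2 are equal.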
func (p *Persist) Equals(p2 *Persist) bool {
	if p == nil && p2 == nil {
		return true
	}
	if p == nil || p2 == nil {
		return false
	}

	return p.LegacyFrontendPrivateMachineKey.Equal(p2.LegacyFrontendPrivateMachineKey) &&
		p.PrivateNodeKey.Equal(p2.PrivateNodeKey) &&
		p.OldPrivateNodeKey.Equal(p2.OldPrivateNodeKey) &&
		p.Provider == p2.Provider &&
		p.LoginName == p2.LoginName
}

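// Pretty returns a short, log-friendly summary of p's public keys and
// login name.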
func (p *Persist) Pretty() string {
	var mk, ok, nk wgcfg.Key
	if !p.LegacyFrontendPrivateMachineKey.IsZero() {
		mk = p.LegacyFrontendPrivateMachineKey.Public()
	}
	if !p.OldPrivateNodeKey.IsZero() {
		ok = p.OldPrivateNodeKey.Public()
	}
	if !p.PrivateNodeKey.IsZero() {
		nk = p.PrivateNodeKey.Public()
	}
	ss := func(k wgcfg.Key) string {
		if k.IsZero() {
			return ""
		}
		return k.ShortString()
	}
	return fmt.Sprintf("Persist{lm=%v, o=%v, n=%v, u=%#v}",
		ss(mk), ss(ok), ss(nk), p.LoginName)
}

// Direct is the client that connects to a tailcontrol server for a node.
type Direct struct {
	httpc           *http.Client // HTTP client used to talk to tailcontrol
	serverURL       string       // URL of the tailcontrol server
	timeNow         func() time.Time
	lastPrintMap    time.Time
	newDecompressor func() (Decompressor, error)
	keepAlive       bool
	logf            logger.Logf
	discoPubKey     tailcfg.DiscoKey
	machinePrivKey  wgcfg.PrivateKey
	debugFlags      []string

	mu           sync.Mutex // mutex guards the following fields
	serverKey    wgcfg.Key
	persist      Persist
	authKey      string
	tryingNewKey wgcfg.PrivateKey
	expiry       *time.Time
	// hostinfo is mutated in-place while mu is held.
	hostinfo      *tailcfg.Hostinfo // always non-nil
	endpoints     []string
	everEndpoints bool   // whether we've ever had non-empty endpoints
	localPort     uint16 // or zero to mean auto
}

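// Options contains options for NewDirect.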
type Options struct {
	Persist           Persist           // initial persistent data
	MachinePrivateKey wgcfg.PrivateKey  // the machine key to use
	ServerURL         string            // URL of the tailcontrol server
	AuthKey           string            // optional node auth key for auto registration
	TimeNow           func() time.Time  // time.Now implementation used by Client
	Hostinfo          *tailcfg.Hostinfo // non-nil passes ownership; nil means to use a default built from os.Hostname, etc
	DiscoPublicKey    tailcfg.DiscoKey
	NewDecompressor   func() (Decompressor, error)
	KeepAlive         bool
	Logf              logger.Logf
	HTTPTestClient    *http.Client // optional HTTP client to use (for tests only)
	DebugFlags        []string     // debug settings to send to control
}

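// Decompressor decodes compressed map responses from the control
// server (see the Compress field set in PollNetMap). A *zstd.Decoder
// from github.com/klauspost/compress/zstd has exactly this method set,
// so a hedged wiring sketch, assuming that package, is:
//
//	opts.NewDecompressor = func() (Decompressor, error) {
//		return zstd.NewReader(nil)
//	}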
type Decompressor interface {
	DecodeAll(input, dst []byte) ([]byte, error)
	Close()
}

// NewDirect returns a new Direct client.
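//
// A minimal usage sketch (the server URL is illustrative):
//
//	mkey, err := wgcfg.NewPrivateKey()
//	if err != nil {
//		log.Fatal(err)
//	}
//	c, err := NewDirect(Options{
//		ServerURL:         "https://login.tailscale.com",
//		MachinePrivateKey: mkey,
//		Logf:              log.Printf,
//	})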
func NewDirect(opts Options) (*Direct, error) {
	if opts.ServerURL == "" {
		return nil, errors.New("controlclient.NewDirect: no server URL specified")
	}
	if opts.MachinePrivateKey.IsZero() {
		return nil, errors.New("controlclient.NewDirect: no MachinePrivateKey specified")
	}
	opts.ServerURL = strings.TrimRight(opts.ServerURL, "/")
	serverURL, err := url.Parse(opts.ServerURL)
	if err != nil {
		return nil, err
	}
	if opts.TimeNow == nil {
		opts.TimeNow = time.Now
	}
	if opts.Logf == nil {
		// TODO(apenwarr): remove this default and fail instead.
		// TODO(bradfitz): ... but then it shouldn't be in Options.
		opts.Logf = log.Printf
	}

	httpc := opts.HTTPTestClient
	if httpc == nil {
		dnsCache := &dnscache.Resolver{
			Forward:     dnscache.Get().Forward, // use default cache's forwarder
			UseLastGood: true,
		}
		dialer := netns.NewDialer()
		tr := http.DefaultTransport.(*http.Transport).Clone()
		tr.Proxy = tshttpproxy.ProxyFromEnvironment
		tshttpproxy.SetTransportGetProxyConnectHeader(tr)
		tr.DialContext = dnscache.Dialer(dialer.DialContext, dnsCache)
		tr.ForceAttemptHTTP2 = true
		tr.TLSClientConfig = tlsdial.Config(serverURL.Host, tr.TLSClientConfig)
		httpc = &http.Client{Transport: tr}
	}

	c := &Direct{
		httpc:           httpc,
		machinePrivKey:  opts.MachinePrivateKey,
		serverURL:       opts.ServerURL,
		timeNow:         opts.TimeNow,
		logf:            opts.Logf,
		newDecompressor: opts.NewDecompressor,
		keepAlive:       opts.KeepAlive,
		persist:         opts.Persist,
		authKey:         opts.AuthKey,
		discoPubKey:     opts.DiscoPublicKey,
		debugFlags:      opts.DebugFlags,
	}
	if opts.Hostinfo == nil {
		c.SetHostinfo(NewHostinfo())
	} else {
		c.SetHostinfo(opts.Hostinfo)
	}
	return c, nil
}

var osVersion func() string // non-nil on some platforms

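// NewHostinfo returns a Hostinfo populated from the local machine:
// its hostname, OS, OS version (where known), and Go architecture.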
func NewHostinfo() *tailcfg.Hostinfo {
	hostname, _ := os.Hostname()
	var osv string
	if osVersion != nil {
		osv = osVersion()
	}
	return &tailcfg.Hostinfo{
		IPNVersion: version.Long,
		Hostname:   hostname,
		OS:         version.OS(),
		OSVersion:  osv,
		GoArch:     runtime.GOARCH,
	}
}

// SetHostinfo clones the provided Hostinfo and remembers it for the
// next update. It reports whether the Hostinfo has changed.
func (c *Direct) SetHostinfo(hi *tailcfg.Hostinfo) bool {
	if hi == nil {
		panic("nil Hostinfo")
	}
	c.mu.Lock()
	defer c.mu.Unlock()

	if hi.Equal(c.hostinfo) {
		return false
	}
	c.hostinfo = hi.Clone()
	j, _ := json.Marshal(c.hostinfo)
	c.logf("HostInfo: %s", j)
	return true
}

// SetNetInfo clones the provided NetInfo and remembers it for the
// next update. It reports whether the NetInfo has changed.
func (c *Direct) SetNetInfo(ni *tailcfg.NetInfo) bool {
	if ni == nil {
		panic("nil NetInfo")
	}
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.hostinfo == nil {
		c.logf("[unexpected] SetNetInfo called with no HostInfo; ignoring NetInfo update: %+v", ni)
		return false
	}
	if reflect.DeepEqual(ni, c.hostinfo.NetInfo) {
		return false
	}
	c.hostinfo.NetInfo = ni.Clone()
	return true
}

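// GetPersist returns the client's current Persist state.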
func (c *Direct) GetPersist() Persist {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.persist
}

// LoginFlags is a bitfield of options that control login behavior.
type LoginFlags int

const (
	LoginDefault     = LoginFlags(0)
	LoginInteractive = LoginFlags(1 << iota) // == 2 (iota is 1 on this line); force user login and key refresh
)

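// TryLogout clears the client's persistent state, including its
// private node keys. It does not yet tell the server to invalidate
// the node key; see the TODO inside.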
func (c *Direct) TryLogout(ctx context.Context) error {
	c.logf("direct.TryLogout()")

	c.mu.Lock()
	defer c.mu.Unlock()

	// TODO(crawshaw): Tell the server. This node key should be
	// immediately invalidated.
	//if !c.persist.PrivateNodeKey.IsZero() {
	//}
	c.persist = Persist{}
	return nil
}

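// TryLogin registers the node with the control server, using the
// optional OAuth token t and the given flags. It returns an auth URL
// for the user to visit if interactive login is needed, or "" otherwise.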
func (c *Direct) TryLogin(ctx context.Context, t *oauth2.Token, flags LoginFlags) (url string, err error) {
	c.logf("direct.TryLogin(token=%v, flags=%v)", t != nil, flags)
	return c.doLoginOrRegen(ctx, t, flags, false, "")
}

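// WaitLoginURL waits for the interactive login flow at url (as
// previously returned by TryLogin) to complete, returning a new URL
// if further interaction is required.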
func (c *Direct) WaitLoginURL(ctx context.Context, url string) (newURL string, err error) {
	c.logf("direct.WaitLoginURL")
	return c.doLoginOrRegen(ctx, nil, LoginDefault, false, url)
}

func (c *Direct) doLoginOrRegen(ctx context.Context, t *oauth2.Token, flags LoginFlags, regen bool, url string) (newURL string, err error) {
	mustregen, url, err := c.doLogin(ctx, t, flags, regen, url)
	if err != nil {
		return url, err
	}
	if mustregen {
		_, url, err = c.doLogin(ctx, t, flags, true, url)
	}
	return url, err
}

func (c *Direct) doLogin(ctx context.Context, t *oauth2.Token, flags LoginFlags, regen bool, url string) (mustregen bool, newURL string, err error) {
	c.mu.Lock()
	persist := c.persist
	tryingNewKey := c.tryingNewKey
	serverKey := c.serverKey
	authKey := c.authKey
	hostinfo := c.hostinfo.Clone()
	backendLogID := hostinfo.BackendLogID
	expired := c.expiry != nil && !c.expiry.IsZero() && c.expiry.Before(c.timeNow())
	c.mu.Unlock()

	if c.machinePrivKey.IsZero() {
		return false, "", errors.New("controlclient.Direct requires a machine private key")
	}

	if expired {
		c.logf("Old key expired -> regen=true")
		regen = true
	}
	if (flags & LoginInteractive) != 0 {
		c.logf("LoginInteractive -> regen=true")
		regen = true
	}

	c.logf("doLogin(regen=%v, hasUrl=%v)", regen, url != "")
	if serverKey == (wgcfg.Key{}) {
		var err error
		serverKey, err = loadServerKey(ctx, c.httpc, c.serverURL)
		if err != nil {
			return regen, url, err
		}

		c.mu.Lock()
		c.serverKey = serverKey
		c.mu.Unlock()
	}

	var oldNodeKey wgcfg.Key
	switch {
	case url != "":
		// A login URL is already pending; keep using the node key
		// we minted for it earlier (c.tryingNewKey).
	case regen || persist.PrivateNodeKey.IsZero():
		c.logf("Generating a new nodekey.")
		persist.OldPrivateNodeKey = persist.PrivateNodeKey
		key, err := wgcfg.NewPrivateKey()
		if err != nil {
			c.logf("login keygen: %v", err)
			return regen, url, err
		}
		tryingNewKey = key
	default:
		// Try refreshing the current key first.
		tryingNewKey = persist.PrivateNodeKey
	}
	if !persist.OldPrivateNodeKey.IsZero() {
		oldNodeKey = persist.OldPrivateNodeKey.Public()
	}

	if tryingNewKey.IsZero() {
		log.Fatalf("tryingNewKey is empty, give up")
	}
	if backendLogID == "" {
		err = errors.New("hostinfo: BackendLogID missing")
		return regen, url, err
	}
	request := tailcfg.RegisterRequest{
		Version:    1,
		OldNodeKey: tailcfg.NodeKey(oldNodeKey),
		NodeKey:    tailcfg.NodeKey(tryingNewKey.Public()),
		Hostinfo:   hostinfo,
		Followup:   url,
	}
	c.logf("RegisterReq: onode=%v node=%v fup=%v",
		request.OldNodeKey.ShortString(),
		request.NodeKey.ShortString(), url != "")
	request.Auth.Oauth2Token = t
	request.Auth.Provider = persist.Provider
	request.Auth.LoginName = persist.LoginName
	request.Auth.AuthKey = authKey
	bodyData, err := encode(request, &serverKey, &c.machinePrivKey)
	if err != nil {
		return regen, url, err
	}
	body := bytes.NewReader(bodyData)

	u := fmt.Sprintf("%s/machine/%s", c.serverURL, c.machinePrivKey.Public().HexString())
	req, err := http.NewRequest("POST", u, body)
	if err != nil {
		return regen, url, err
	}
	req = req.WithContext(ctx)

	res, err := c.httpc.Do(req)
	if err != nil {
		return regen, url, fmt.Errorf("register request: %v", err)
	}
	if res.StatusCode != 200 {
		msg, _ := ioutil.ReadAll(res.Body)
		res.Body.Close()
		return regen, url, fmt.Errorf("register request: http %d: %.200s",
			res.StatusCode, strings.TrimSpace(string(msg)))
	}
	resp := tailcfg.RegisterResponse{}
	if err := decode(res, &resp, &serverKey, &c.machinePrivKey); err != nil {
		c.logf("error decoding RegisterResponse with server key %s and machine key %s: %v", serverKey, c.machinePrivKey.Public(), err)
		return regen, url, fmt.Errorf("register request: %v", err)
	}
	// Log without PII:
	c.logf("RegisterReq: got response; nodeKeyExpired=%v, machineAuthorized=%v; authURL=%v",
		resp.NodeKeyExpired, resp.MachineAuthorized, resp.AuthURL != "")

	if resp.NodeKeyExpired {
		if regen {
			return true, "", fmt.Errorf("weird: regen=true but server says NodeKeyExpired: %v", request.NodeKey)
		}
		c.logf("server reports new node key %v has expired",
			request.NodeKey.ShortString())
		return true, "", nil
	}
	if persist.Provider == "" {
		persist.Provider = resp.Login.Provider
	}
	if persist.LoginName == "" {
		persist.LoginName = resp.Login.LoginName
	}

	// TODO(crawshaw): RegisterResponse should be able to mechanically
	// communicate some extra instructions from the server:
	//	- new node key required
	//	- machine key no longer supported
	//	- user is disabled

	if resp.AuthURL != "" {
		c.logf("AuthURL is %v", resp.AuthURL)
	} else {
		c.logf("No AuthURL")
	}

	c.mu.Lock()
	if resp.AuthURL == "" {
		// key rotation is complete
		persist.PrivateNodeKey = tryingNewKey
	} else {
		// save it for the retry-with-URL
		c.tryingNewKey = tryingNewKey
	}
	c.persist = persist
	c.mu.Unlock()

	if err != nil {
		return regen, "", err
	}
	if ctx.Err() != nil {
		return regen, "", ctx.Err()
	}
	return false, resp.AuthURL, nil
}

func sameStrings(a, b []string) bool {
	if len(a) != len(b) {
		return false
	}
	for i := range a {
		if a[i] != b[i] {
			return false
		}
	}
	return true
}

// newEndpoints acquires c.mu, sets the local port and endpoints, and
// reports whether they've changed.
//
// It does not retain the provided slice.
func (c *Direct) newEndpoints(localPort uint16, endpoints []string) (changed bool) {
	c.mu.Lock()
	defer c.mu.Unlock()

	// Nothing new?
	if c.localPort == localPort && sameStrings(c.endpoints, endpoints) {
		return false // unchanged
	}
	c.logf("client.newEndpoints(%v, %v)", localPort, endpoints)
	c.localPort = localPort
	c.endpoints = append(c.endpoints[:0], endpoints...)
	if len(endpoints) > 0 {
		c.everEndpoints = true
	}
	return true // changed
}

// SetEndpoints updates the list of locally advertised endpoints.
// It won't be replicated to the server until a *fresh* call to PollNetMap().
// You don't need to restart PollNetMap if this returns changed==false.
func (c *Direct) SetEndpoints(localPort uint16, endpoints []string) (changed bool) {
	// (no log message on function entry, because it clutters the logs
	//  if endpoints haven't changed. newEndpoints() will log it.)
	return c.newEndpoints(localPort, endpoints)
}

func inTest() bool { return flag.Lookup("test.v") != nil }

// PollNetMap makes a /map request to download the network map, calling cb with
// each new netmap.
//
// maxPolls is how many network maps to download; common values are 1
// or -1 (to keep a long-poll query open to the server).
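//
// A hedged usage sketch (the callback body is illustrative):
//
//	err := c.PollNetMap(ctx, -1, func(nm *NetworkMap) {
//		log.Printf("netmap: %d peers", len(nm.Peers))
//	})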
func (c *Direct) PollNetMap(ctx context.Context, maxPolls int, cb func(*NetworkMap)) error {
	c.mu.Lock()
	persist := c.persist
	serverURL := c.serverURL
	serverKey := c.serverKey
	hostinfo := c.hostinfo.Clone()
	backendLogID := hostinfo.BackendLogID
	localPort := c.localPort
	ep := append([]string(nil), c.endpoints...)
	everEndpoints := c.everEndpoints
	c.mu.Unlock()

	if backendLogID == "" {
		return errors.New("hostinfo: BackendLogID missing")
	}

	allowStream := maxPolls != 1
	c.logf("PollNetMap: stream=%v :%v ep=%v", allowStream, localPort, ep)

	vlogf := logger.Discard
	if Debug.NetMap {
		vlogf = c.logf
	}

	request := tailcfg.MapRequest{
		Version:    5,
		KeepAlive:  c.keepAlive,
		NodeKey:    tailcfg.NodeKey(persist.PrivateNodeKey.Public()),
		DiscoKey:   c.discoPubKey,
		Endpoints:  ep,
		Stream:     allowStream,
		Hostinfo:   hostinfo,
		DebugFlags: c.debugFlags,
	}
	if hostinfo != nil && ipForwardingBroken(hostinfo.RoutableIPs) {
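		// Extend with a full slice expression so the append can't
		// scribble over the shared c.debugFlags backing array.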
		old := request.DebugFlags
		request.DebugFlags = append(old[:len(old):len(old)], "warn-ip-forwarding-off")
	}
	if c.newDecompressor != nil {
		request.Compress = "zstd"
	}
	// On initial startup before we know our endpoints, set the ReadOnly flag
	// to tell the control server not to distribute our (empty) endpoints to peers.
	// Presumably we'll learn our endpoints in a half second and do another post
	// with useful results. The first POST just gets us the DERP map which we
	// need to do the STUN queries to discover our endpoints.
	// TODO(bradfitz): we skip this optimization in tests, though,
	// because the e2e tests are currently hyperspecific about the
	// ordering of things. The e2e tests need love.
	if len(ep) == 0 && !everEndpoints && !inTest() {
		request.ReadOnly = true
	}

	bodyData, err := encode(request, &serverKey, &c.machinePrivKey)
	if err != nil {
		vlogf("netmap: encode: %v", err)
		return err
	}

	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	machinePubKey := tailcfg.MachineKey(c.machinePrivKey.Public())
	t0 := time.Now()
	u := fmt.Sprintf("%s/machine/%s/map", serverURL, machinePubKey.HexString())

	req, err := http.NewRequestWithContext(ctx, "POST", u, bytes.NewReader(bodyData))
	if err != nil {
		return err
	}

	res, err := c.httpc.Do(req)
	if err != nil {
		vlogf("netmap: Do: %v", err)
		return err
	}
	vlogf("netmap: Do = %v after %v", res.StatusCode, time.Since(t0).Round(time.Millisecond))
	if res.StatusCode != 200 {
		msg, _ := ioutil.ReadAll(res.Body)
		res.Body.Close()
		return fmt.Errorf("initial fetch failed %d: %.200s",
			res.StatusCode, strings.TrimSpace(string(msg)))
	}
	defer res.Body.Close()

	// If we go more than pollTimeout without hearing from the server,
	// end the long poll. We should be receiving a keep-alive ping
	// every minute.
	const pollTimeout = 120 * time.Second
	timeout := time.NewTimer(pollTimeout)
	timeoutReset := make(chan struct{})
	pollDone := make(chan struct{})
	defer close(pollDone)
	go func() {
		for {
			select {
			case <-pollDone:
				vlogf("netmap: ending timeout goroutine")
				return
			case <-timeout.C:
				c.logf("map response long-poll timed out!")
				cancel()
				return
			case <-timeoutReset:
				if !timeout.Stop() {
					select {
					case <-timeout.C:
					case <-pollDone:
						vlogf("netmap: ending timeout goroutine")
						return
					}
				}
				vlogf("netmap: reset timeout timer")
				timeout.Reset(pollTimeout)
			}
		}
	}()

	var lastDERPMap *tailcfg.DERPMap
	var lastUserProfile = map[tailcfg.UserID]tailcfg.UserProfile{}

	// If allowStream, then the server will use an HTTP long poll to
	// return incremental results. There is always one response right
	// away, followed by a delay, and eventually others.
	// If !allowStream, it'll still send the first result in exactly
	// the same format before just closing the connection.
	// We can use this same read loop either way.
	var msg []byte
	var previousPeers []*tailcfg.Node // for delta-purposes
	for i := 0; i < maxPolls || maxPolls < 0; i++ {
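		// Each message on the stream is framed as a 4-byte
		// little-endian length, followed by that many bytes of
		// NaCl-boxed (and possibly zstd-compressed) JSON.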
		vlogf("netmap: starting size read after %v (poll %v)", time.Since(t0).Round(time.Millisecond), i)
		var siz [4]byte
		if _, err := io.ReadFull(res.Body, siz[:]); err != nil {
			vlogf("netmap: size read error after %v: %v", time.Since(t0).Round(time.Millisecond), err)
			return err
		}
		size := binary.LittleEndian.Uint32(siz[:])
		vlogf("netmap: read size %v after %v", size, time.Since(t0).Round(time.Millisecond))
		msg = append(msg[:0], make([]byte, size)...)
		if _, err := io.ReadFull(res.Body, msg); err != nil {
			vlogf("netmap: body read error: %v", err)
			return err
		}
		vlogf("netmap: read body after %v", time.Since(t0).Round(time.Millisecond))

		var resp tailcfg.MapResponse
		if err := c.decodeMsg(msg, &resp); err != nil {
			vlogf("netmap: decode error: %v", err)
			return err
		}

		if resp.KeepAlive {
			vlogf("netmap: got keep-alive")
		} else {
			vlogf("netmap: got new map")
		}
		select {
		case timeoutReset <- struct{}{}:
			vlogf("netmap: sent timer reset")
		case <-ctx.Done():
			c.logf("netmap: not resetting timer; context done: %v", ctx.Err())
			return ctx.Err()
		}
		if resp.KeepAlive {
			continue
		}

		undeltaPeers(&resp, previousPeers)
		previousPeers = cloneNodes(resp.Peers) // defensive/lazy clone, since this escapes to who knows where
		for _, up := range resp.UserProfiles {
			lastUserProfile[up.ID] = up
		}

		if resp.DERPMap != nil {
			vlogf("netmap: new map contains DERP map")
			lastDERPMap = resp.DERPMap
		}
		if resp.Debug != nil {
			if resp.Debug.LogHeapPprof {
				go logheap.LogHeap(resp.Debug.LogHeapURL)
			}
			setControlAtomic(&controlUseDERPRoute, resp.Debug.DERPRoute)
			setControlAtomic(&controlTrimWGConfig, resp.Debug.TrimWGConfig)
		}
		// Temporarily (2020-06-29) support removing all but
		// discovery-supporting nodes during development, for
		// less noise.
		if Debug.OnlyDisco {
			filtered := resp.Peers[:0]
			for _, p := range resp.Peers {
				if !p.DiscoKey.IsZero() {
					filtered = append(filtered, p)
				}
			}
			resp.Peers = filtered
		}

		nm := &NetworkMap{
			NodeKey:      tailcfg.NodeKey(persist.PrivateNodeKey.Public()),
			PrivateKey:   persist.PrivateNodeKey,
			MachineKey:   machinePubKey,
			Expiry:       resp.Node.KeyExpiry,
			Name:         resp.Node.Name,
			Addresses:    resp.Node.Addresses,
			Peers:        resp.Peers,
			LocalPort:    localPort,
			User:         resp.Node.User,
			UserProfiles: make(map[tailcfg.UserID]tailcfg.UserProfile),
			Domain:       resp.Domain,
			DNS:          resp.DNSConfig,
			Hostinfo:     resp.Node.Hostinfo,
			PacketFilter: c.parsePacketFilter(resp.PacketFilter),
			DERPMap:      lastDERPMap,
			Debug:        resp.Debug,
		}
		addUserProfile := func(userID tailcfg.UserID) {
			if _, dup := nm.UserProfiles[userID]; dup {
				// Already populated it from a previous peer.
				return
			}
			if up, ok := lastUserProfile[userID]; ok {
				nm.UserProfiles[userID] = up
			}
		}
		addUserProfile(nm.User)
		for _, peer := range resp.Peers {
			addUserProfile(peer.User)
		}
		if resp.Node.MachineAuthorized {
			nm.MachineStatus = tailcfg.MachineAuthorized
		} else {
			nm.MachineStatus = tailcfg.MachineUnauthorized
		}
		if len(resp.DNS) > 0 {
			nm.DNS.Nameservers = wgIPToNetaddr(resp.DNS)
		}
		if len(resp.SearchPaths) > 0 {
			nm.DNS.Domains = resp.SearchPaths
		}
		if Debug.ProxyDNS {
			nm.DNS.Proxied = true
		}

		// Printing the netmap can be extremely verbose, but is very
		// handy for debugging. Let's limit how often we do it.
		// Code elsewhere prints netmap diffs every time, so this
		// occasional full dump, plus incremental diffs, should do
		// the job.
		now := c.timeNow()
		if now.Sub(c.lastPrintMap) >= 5*time.Minute {
			c.lastPrintMap = now
			c.logf("new network map[%d]:\n%s", i, nm.Concise())
		}

		c.mu.Lock()
		c.expiry = &nm.Expiry
		c.mu.Unlock()

		cb(nm)
	}
	if ctx.Err() != nil {
		return ctx.Err()
	}
	return nil
}

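// decode reads and decrypts an HTTP response body (limited to 1MB)
// into v.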
func decode(res *http.Response, v interface{}, serverKey *wgcfg.Key, mkey *wgcfg.PrivateKey) error {
	defer res.Body.Close()
	msg, err := ioutil.ReadAll(io.LimitReader(res.Body, 1<<20))
	if err != nil {
		return err
	}
	if res.StatusCode != 200 {
		return fmt.Errorf("%d: %v", res.StatusCode, string(msg))
	}
	return decodeMsg(msg, v, serverKey, mkey)
}

var debugMap, _ = strconv.ParseBool(os.Getenv("TS_DEBUG_MAP"))

var jsonEscapedZero = []byte(`\u0000`)

func (c *Direct) decodeMsg(msg []byte, v interface{}) error {
	c.mu.Lock()
	serverKey := c.serverKey
	c.mu.Unlock()

	decrypted, err := decryptMsg(msg, &serverKey, &c.machinePrivKey)
	if err != nil {
		return err
	}
	var b []byte
	if c.newDecompressor == nil {
		b = decrypted
	} else {
		decoder, err := c.newDecompressor()
		if err != nil {
			return err
		}
		defer decoder.Close()
		b, err = decoder.DecodeAll(decrypted, nil)
		if err != nil {
			return err
		}
	}
	if debugMap {
		var buf bytes.Buffer
		json.Indent(&buf, b, "", "    ")
		log.Printf("MapResponse: %s", buf.Bytes())
	}

	if bytes.Contains(b, jsonEscapedZero) {
		log.Printf("[unexpected] zero byte in controlclient.Direct.decodeMsg into %T: %q", v, b)
	}
	if err := json.Unmarshal(b, v); err != nil {
		return fmt.Errorf("response: %v", err)
	}
	return nil
}

func decodeMsg(msg []byte, v interface{}, serverKey *wgcfg.Key, mkey *wgcfg.PrivateKey) error {
	decrypted, err := decryptMsg(msg, serverKey, mkey)
	if err != nil {
		return err
	}
	if bytes.Contains(decrypted, jsonEscapedZero) {
		log.Printf("[unexpected] zero byte in controlclient decodeMsg into %T: %q", v, decrypted)
	}
	if err := json.Unmarshal(decrypted, v); err != nil {
		return fmt.Errorf("response: %v", err)
	}
	return nil
}

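// decryptMsg opens a message from the control server: a 24-byte nonce
// followed by a NaCl box sealed with the server's public key and the
// node's machine private key.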
func decryptMsg(msg []byte, serverKey *wgcfg.Key, mkey *wgcfg.PrivateKey) ([]byte, error) {
	var nonce [24]byte
	if len(msg) < len(nonce)+1 {
		return nil, fmt.Errorf("response missing nonce, len=%d", len(msg))
	}
	copy(nonce[:], msg)
	msg = msg[len(nonce):]

	pub, pri := (*[32]byte)(serverKey), (*[32]byte)(mkey)
	decrypted, ok := box.Open(nil, msg, &nonce, pub, pri)
	if !ok {
		return nil, fmt.Errorf("cannot decrypt response (len %d + nonce %d = %d)", len(msg), len(nonce), len(msg)+len(nonce))
	}
	return decrypted, nil
}

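// encode JSON-marshals v and seals it for the control server,
// prepending the random 24-byte nonce to the boxed payload (the
// inverse of decryptMsg's wire format).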
func encode(v interface{}, serverKey *wgcfg.Key, mkey *wgcfg.PrivateKey) ([]byte, error) {
	b, err := json.Marshal(v)
	if err != nil {
		return nil, err
	}
	if debugMap {
		if _, ok := v.(tailcfg.MapRequest); ok {
			log.Printf("MapRequest: %s", b)
		}
	}
	var nonce [24]byte
	if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil {
		panic(err)
	}
	pub, pri := (*[32]byte)(serverKey), (*[32]byte)(mkey)
	msg := box.Seal(nonce[:], b, &nonce, pub, pri)
	return msg, nil
}

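// loadServerKey fetches the control server's public machine key,
// served as hex at serverURL/key.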
func loadServerKey(ctx context.Context, httpc *http.Client, serverURL string) (wgcfg.Key, error) {
	req, err := http.NewRequest("GET", serverURL+"/key", nil)
	if err != nil {
		return wgcfg.Key{}, fmt.Errorf("create control key request: %v", err)
	}
	req = req.WithContext(ctx)
	res, err := httpc.Do(req)
	if err != nil {
		return wgcfg.Key{}, fmt.Errorf("fetch control key: %v", err)
	}
	defer res.Body.Close()
	b, err := ioutil.ReadAll(io.LimitReader(res.Body, 1<<16))
	if err != nil {
		return wgcfg.Key{}, fmt.Errorf("fetch control key response: %v", err)
	}
	if res.StatusCode != 200 {
		return wgcfg.Key{}, fmt.Errorf("fetch control key: %d: %s", res.StatusCode, string(b))
	}
	key, err := wgcfg.ParseHexKey(string(b))
	if err != nil {
		return wgcfg.Key{}, fmt.Errorf("fetch control key: %v", err)
	}
	return key, nil
}

func wgIPToNetaddr(ips []wgcfg.IP) (ret []netaddr.IP) {
	for _, ip := range ips {
		nip, ok := netaddr.FromStdIP(ip.IP())
		if !ok {
			panic(fmt.Sprintf("conversion of %s from wgcfg to netaddr IP failed", ip))
		}
		ret = append(ret, nip.Unmap())
	}
	return ret
}

// Debug contains temporary internal-only debug knobs, set from
// environment variables at startup. The struct type is unexported to
// avoid drawing attention to them.
var Debug = initDebug()

type debug struct {
	NetMap    bool
	ProxyDNS  bool
	OnlyDisco bool
	Disco     bool
}

func initDebug() debug {
	use := os.Getenv("TS_DEBUG_USE_DISCO")
	return debug{
		NetMap:    envBool("TS_DEBUG_NETMAP"),
		ProxyDNS:  envBool("TS_DEBUG_PROXY_DNS"),
		OnlyDisco: use == "only",
		Disco:     use == "only" || use == "" || envBool("TS_DEBUG_USE_DISCO"),
	}
}

func envBool(k string) bool {
	e := os.Getenv(k)
	if e == "" {
		return false
	}
	v, err := strconv.ParseBool(e)
	if err != nil {
		panic(fmt.Sprintf("invalid non-bool %q for env var %q", e, k))
	}
	return v
}

// undeltaPeers updates mapRes.Peers to be complete based on the provided previous peer list
// and the PeersRemoved and PeersChanged fields in mapRes.
// It then also nils out the delta fields.
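//
// For example (with IDs standing in for nodes), prev=[1, 2, 3] with
// PeersChanged=[2'] and PeersRemoved=[3] yields Peers=[1, 2'].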
func undeltaPeers(mapRes *tailcfg.MapResponse, prev []*tailcfg.Node) {
	if len(mapRes.Peers) > 0 {
		// Not delta encoded.
		if !nodesSorted(mapRes.Peers) {
			log.Printf("netmap: undeltaPeers: MapResponse.Peers not sorted; sorting")
			sortNodes(mapRes.Peers)
		}
		return
	}

	var removed map[tailcfg.NodeID]bool
	if pr := mapRes.PeersRemoved; len(pr) > 0 {
		removed = make(map[tailcfg.NodeID]bool, len(pr))
		for _, id := range pr {
			removed[id] = true
		}
	}
	changed := mapRes.PeersChanged

	if len(removed) == 0 && len(changed) == 0 {
		// No changes fast path.
		mapRes.Peers = prev
		return
	}

	if !nodesSorted(changed) {
		log.Printf("netmap: undeltaPeers: MapResponse.PeersChanged not sorted; sorting")
		sortNodes(changed)
	}
	if !nodesSorted(prev) {
		// Internal error (unrelated to the network) if we get here.
		log.Printf("netmap: undeltaPeers: [unexpected] prev not sorted; sorting")
		sortNodes(prev)
	}

	// The capacity is only a hint; note that len(prev)-len(removed)
	// could go negative if the server sends removals for unknown nodes.
	newFull := make([]*tailcfg.Node, 0, len(prev))
	for len(prev) > 0 && len(changed) > 0 {
		pID := prev[0].ID
		cID := changed[0].ID
		if removed[pID] {
			prev = prev[1:]
			continue
		}
		switch {
		case pID < cID:
			newFull = append(newFull, prev[0])
			prev = prev[1:]
		case pID == cID:
			newFull = append(newFull, changed[0])
			prev, changed = prev[1:], changed[1:]
		case cID < pID:
			newFull = append(newFull, changed[0])
			changed = changed[1:]
		}
	}
	newFull = append(newFull, changed...)
	for _, n := range prev {
		if !removed[n.ID] {
			newFull = append(newFull, n)
		}
	}
	sortNodes(newFull)
	mapRes.Peers = newFull
	mapRes.PeersChanged = nil
	mapRes.PeersRemoved = nil
}

func nodesSorted(v []*tailcfg.Node) bool {
	for i, n := range v {
		if i > 0 && n.ID <= v[i-1].ID {
			return false
		}
	}
	return true
}

func sortNodes(v []*tailcfg.Node) {
	sort.Slice(v, func(i, j int) bool { return v[i].ID < v[j].ID })
}

func cloneNodes(v1 []*tailcfg.Node) []*tailcfg.Node {
	if v1 == nil {
		return nil
	}
	v2 := make([]*tailcfg.Node, len(v1))
	for i, n := range v1 {
		v2[i] = n.Clone()
	}
	return v2
}

// opt.Bool configs from control.
var (
	controlUseDERPRoute atomic.Value
	controlTrimWGConfig atomic.Value
)

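// setControlAtomic stores v into dst only when it differs from the
// current value, avoiding redundant atomic stores.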
func setControlAtomic(dst *atomic.Value, v opt.Bool) {
	old, ok := dst.Load().(opt.Bool)
	if !ok || old != v {
		dst.Store(v)
	}
}

// DERPRouteFlag reports the last reported value from control for whether
// DERP route optimization (Issue 150) should be enabled.
func DERPRouteFlag() opt.Bool {
	v, _ := controlUseDERPRoute.Load().(opt.Bool)
	return v
}

// TrimWGConfig reports the last reported value from control for whether
// we should do lazy wireguard configuration.
func TrimWGConfig() opt.Bool {
	v, _ := controlTrimWGConfig.Load().(opt.Bool)
	return v
}

// ipForwardingBroken reports whether the system's IP forwarding is disabled
// and will definitely not work for the routes provided.
//
// It should not return false positives.
func ipForwardingBroken(routes []wgcfg.CIDR) bool {
	if len(routes) == 0 {
		// Nothing to route, so no need to warn.
		return false
	}
	if runtime.GOOS != "linux" {
		// We only do subnet routing on Linux for now.
		// It might work on darwin/macOS when building from source, so
		// don't return true for other OSes. The admin panel already
		// shows OS-based warnings.
		return false
	}
	out, err := ioutil.ReadFile("/proc/sys/net/ipv4/ip_forward")
	if err != nil {
		// Try another way.
		out, err = exec.Command("sysctl", "-n", "net.ipv4.ip_forward").Output()
	}
	if err != nil {
		// Oh well, we tried. This is just for debugging.
		// We don't want false positives.
		// TODO: maybe we want a different warning for inability to check?
		return false
	}
	return strings.TrimSpace(string(out)) == "0"
	// TODO: also check IPv6 if 'routes' contains any IPv6 routes
}