// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package controlbase

import (
	"context"
	"crypto/cipher"
	"encoding/binary"
	"errors"
	"fmt"
	"hash"
	"io"
	"net"
	"strconv"
	"time"

	"go4.org/mem"
	"golang.org/x/crypto/blake2s"
	chp "golang.org/x/crypto/chacha20poly1305"
	"golang.org/x/crypto/curve25519"
	"golang.org/x/crypto/hkdf"
	"tailscale.com/types/key"
)

const (
	// protocolName is the name of the specific instantiation of Noise
	// that the control protocol uses. This string's value is fixed by
	// the Noise spec, and shouldn't be changed unless we're updating
	// the control protocol to use a different Noise instance.
	protocolName = "Noise_IK_25519_ChaChaPoly_BLAKE2s"

	// The protocol version is not a constant in this file: Client and
	// ClientDeferred receive it as a parameter, and Server reads it
	// from the client's initiation message. The disabled declaration
	// below is kept only for reference.
	//
	//protocolVersion uint16 = 1

	// protocolVersionPrefix is the name portion of the protocol
	// name+version string that gets mixed into the handshake as a
	// prologue (see protocolVersionPrologue).
	//
	// This mixing verifies that both peers agree that they're
	// executing the control protocol at a specific version that
	// matches the advertised version in the cleartext packet header.
	protocolVersionPrefix = "Tailscale Control Protocol v"
	invalidNonce          = ^uint64(0) // all bits set; not referenced in the code shown here, presumably a "no valid nonce" sentinel (TODO confirm)
)

// protocolVersionPrologue returns the handshake prologue for the
// given protocol version: protocolVersionPrefix immediately followed
// by version rendered in decimal.
func protocolVersionPrologue(version uint16) []byte {
	// A uint16 never needs more than 5 decimal digits, so reserving
	// that much room up front means AppendUint never reallocates.
	const maxVersionDigits = 5
	buf := make([]byte, len(protocolVersionPrefix), len(protocolVersionPrefix)+maxVersionDigits)
	copy(buf, protocolVersionPrefix)
	return strconv.AppendUint(buf, uint64(version), 10)
}

// HandshakeContinuation finishes a handshake over conn and upgrades
// it to a Conn. It assumes the client>server handshake initiation
// message has already been sent on conn by the time it is called.
type HandshakeContinuation func(ctx context.Context, conn net.Conn) (*Conn, error)

// ClientDeferred starts a control client handshake. It returns the
// initiation message that must be sent to the server, plus a
// continuation that completes the handshake once that message is on
// the wire.
//
// The two-step shape exists to save a round trip. This protocol runs
// after an HTTP/HTTPS protocol switch; if the initiation could only be
// sent after the switch completed, the client would spend one extra
// RTT. Returning the initiation bytes separately lets the caller embed
// them in the protocol-switch request itself.
//
// machineKey is the client's long-term private key, controlKey is the
// server's long-term public key, and protocolVersion is both placed in
// the initiation message and mixed into the handshake prologue.
func ClientDeferred(machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (initialHandshake []byte, continueHandshake HandshakeContinuation, err error) {
	var state symmetricState
	state.Initialize()

	// Prologue, followed by the "<- s" pre-message: the server's
	// static public key is known to the client ahead of time.
	state.MixHash(protocolVersionPrologue(protocolVersion))
	state.MixHash(controlKey.UntypedBytes())

	// -> e, es, s, ss
	msg := mkInitiationMessage(protocolVersion)
	ephemeral := key.NewMachine()
	ephemeralPub := ephemeral.Public()
	copy(msg.EphemeralPub(), ephemeralPub.UntypedBytes())
	state.MixHash(ephemeralPub.UntypedBytes())

	esAEAD, err := state.MixDH(ephemeral, controlKey)
	if err != nil {
		return nil, nil, fmt.Errorf("computing es: %w", err)
	}
	// The client's static public key travels encrypted under es.
	machinePub := machineKey.Public()
	state.EncryptAndHash(esAEAD, msg.MachinePub(), machinePub.UntypedBytes())

	ssAEAD, err := state.MixDH(machineKey, controlKey)
	if err != nil {
		return nil, nil, fmt.Errorf("computing ss: %w", err)
	}
	// Empty message payload: only the authentication tag is produced.
	state.EncryptAndHash(ssAEAD, msg.Tag(), nil)

	// The continuation keeps a pointer to state, so the handshake
	// resumes exactly where this function left off.
	continueHandshake = func(ctx context.Context, conn net.Conn) (*Conn, error) {
		return continueClientHandshake(ctx, conn, &state, machineKey, ephemeral, controlKey, protocolVersion)
	}
	return msg[:], continueHandshake, nil
}

// Client performs the whole client side of the handshake on conn in
// one synchronous call: it builds the initiation message with
// ClientDeferred, writes it to conn, and then runs the returned
// continuation on the same conn.
//
// Use it when the split, continuation-style handshake offered by
// ClientDeferred isn't needed and a net.Conn just has to be upgraded
// to a secure transport.
func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (*Conn, error) {
	hello, finish, err := ClientDeferred(machineKey, controlKey, protocolVersion)
	if err != nil {
		return nil, err
	}
	_, err = conn.Write(hello)
	if err != nil {
		return nil, err
	}
	return finish(ctx, conn)
}

// continueClientHandshake completes the client side of a handshake
// started by ClientDeferred. It reads the server's reply from conn,
// processes it against the in-flight state s, and on success returns
// a Conn carrying the derived transport ciphers.
//
// If ctx has a deadline it is applied to conn for the duration of the
// call and cleared again on return. s is marked finished when the
// function returns, whatever the outcome.
func continueClientHandshake(ctx context.Context, conn net.Conn, s *symmetricState, machineKey, machineEphemeral key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (*Conn, error) {
	// A given s may drive this function exactly once. Marking it
	// finished on the way out makes any later use of s panic.
	defer func() { s.finished = true }()

	if dl, hasDeadline := ctx.Deadline(); hasDeadline {
		if err := conn.SetDeadline(dl); err != nil {
			return nil, fmt.Errorf("setting conn deadline: %w", err)
		}
		defer func() {
			conn.SetDeadline(time.Time{})
		}()
	}

	// Read the reply header, then decide how to treat the rest based
	// on the message type the server sent.
	var resp responseMessage
	if _, err := io.ReadFull(conn, resp.Header()); err != nil {
		return nil, fmt.Errorf("reading response header: %w", err)
	}
	switch typ := resp.Type(); typ {
	case msgTypeResponse:
		// Expected reply; handled below.
	case msgTypeError:
		// The server refused the handshake and sent a reason.
		reason := make([]byte, resp.Length())
		if _, err := io.ReadFull(conn, reason); err != nil {
			return nil, err
		}
		return nil, fmt.Errorf("server error: %q", reason)
	default:
		return nil, fmt.Errorf("unexpected response message type %d", typ)
	}

	if got := resp.Length(); got != len(resp.Payload()) {
		return nil, fmt.Errorf("wrong length %d received for handshake response", got)
	}
	if _, err := io.ReadFull(conn, resp.Payload()); err != nil {
		return nil, err
	}

	// <- e, ee, se
	serverEphemeralPub := key.MachinePublicFromRaw32(mem.B(resp.EphemeralPub()))
	s.MixHash(serverEphemeralPub.UntypedBytes())
	if _, err := s.MixDH(machineEphemeral, serverEphemeralPub); err != nil {
		return nil, fmt.Errorf("computing ee: %w", err)
	}
	seAEAD, err := s.MixDH(machineKey, serverEphemeralPub)
	if err != nil {
		return nil, fmt.Errorf("computing se: %w", err)
	}
	// The response carries no payload, only an authentication tag.
	if err := s.DecryptAndHash(seAEAD, nil, resp.Tag()); err != nil {
		return nil, fmt.Errorf("decrypting payload: %w", err)
	}

	sendCipher, recvCipher, err := s.Split()
	if err != nil {
		return nil, fmt.Errorf("finalizing handshake: %w", err)
	}

	// The client transmits with the first cipher from Split and
	// receives with the second; Server uses them the other way round.
	return &Conn{
		conn:          conn,
		version:       protocolVersion,
		peer:          controlKey,
		handshakeHash: s.h,
		tx:            txState{cipher: sendCipher},
		rx:            rxState{cipher: recvCipher},
	}, nil
}

// Server runs the server side of a control handshake on conn and
// returns the resulting control connection.
//
// controlKey is the server's long-term private key.
//
// optionalInit can be the client's initial handshake message as
// returned by ClientDeferred, or nil in which case the initial
// message is read from conn.
//
// The context deadline, if any, covers the entire handshaking
// process: it is applied to conn on entry and cleared on return.
//
// Malformed initiation messages (wrong size, type or length) are
// answered with an error frame before Server returns an error.
// Cryptographic failures return an error without writing to conn.
func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, optionalInit []byte) (*Conn, error) {
	if deadline, ok := ctx.Deadline(); ok {
		if err := conn.SetDeadline(deadline); err != nil {
			return nil, fmt.Errorf("setting conn deadline: %w", err)
		}
		defer func() {
			conn.SetDeadline(time.Time{})
		}()
	}

	// sendErr writes an error frame carrying msg to the client and
	// returns an error describing the refusal (or the write failure).
	//
	// Deliberately does not support formatting, so that we don't echo
	// attacker-controlled input back to them.
	sendErr := func(msg string) error {
		// The frame's length field is a uint16, so the largest
		// message that can be described is 1<<16 - 1 bytes. Clamping
		// to 1<<16 would make uint16(len(msg)) wrap around to 0 while
		// the full message is still written after the header.
		const maxErrLen = 1<<16 - 1
		if len(msg) > maxErrLen {
			msg = msg[:maxErrLen]
		}
		var hdr [headerLen]byte
		hdr[0] = msgTypeError
		binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg)))
		if _, err := conn.Write(hdr[:]); err != nil {
			return fmt.Errorf("sending %q error to client: %w", msg, err)
		}
		if _, err := io.WriteString(conn, msg); err != nil {
			return fmt.Errorf("sending %q error to client: %w", msg, err)
		}
		return fmt.Errorf("refused client handshake: %q", msg)
	}

	var s symmetricState
	s.Initialize()

	// Obtain the initiation message: either the caller already has
	// it in full, or only its header is read here so it can be
	// validated before the payload is read.
	var init initiationMessage
	if optionalInit != nil {
		if len(optionalInit) != len(init) {
			return nil, sendErr("wrong handshake initiation size")
		}
		copy(init[:], optionalInit)
	} else if _, err := io.ReadFull(conn, init.Header()); err != nil {
		return nil, err
	}
	// Just a rename to make it more obvious what the value is. In the
	// current implementation we don't need to block any protocol
	// versions at this layer, it's safe to let the handshake proceed
	// and then let the caller make decisions based on the agreed-upon
	// protocol version.
	clientVersion := init.Version()
	if init.Type() != msgTypeInitiation {
		return nil, sendErr("unexpected handshake message type")
	}
	if init.Length() != len(init.Payload()) {
		return nil, sendErr("wrong handshake initiation length")
	}
	// if optionalInit was provided, we have the payload already.
	if optionalInit == nil {
		if _, err := io.ReadFull(conn, init.Payload()); err != nil {
			return nil, err
		}
	}

	// prologue. Can only do this once we at least think the client is
	// handshaking using a supported version.
	s.MixHash(protocolVersionPrologue(clientVersion))

	// <- s
	// ...
	controlKeyPub := controlKey.Public()
	s.MixHash(controlKeyPub.UntypedBytes())

	// -> e, es, s, ss
	machineEphemeralPub := key.MachinePublicFromRaw32(mem.B(init.EphemeralPub()))
	s.MixHash(machineEphemeralPub.UntypedBytes())
	cipher, err := s.MixDH(controlKey, machineEphemeralPub)
	if err != nil {
		return nil, fmt.Errorf("computing es: %w", err)
	}
	// The client's static public key arrives encrypted under es.
	var machineKeyBytes [32]byte
	if err := s.DecryptAndHash(cipher, machineKeyBytes[:], init.MachinePub()); err != nil {
		return nil, fmt.Errorf("decrypting machine key: %w", err)
	}
	machineKey := key.MachinePublicFromRaw32(mem.B(machineKeyBytes[:]))
	cipher, err = s.MixDH(controlKey, machineKey)
	if err != nil {
		return nil, fmt.Errorf("computing ss: %w", err)
	}
	// Empty payload: only the authentication tag is checked.
	if err := s.DecryptAndHash(cipher, nil, init.Tag()); err != nil {
		return nil, fmt.Errorf("decrypting initiation tag: %w", err)
	}

	// <- e, ee, se
	resp := mkResponseMessage()
	controlEphemeral := key.NewMachine()
	controlEphemeralPub := controlEphemeral.Public()
	copy(resp.EphemeralPub(), controlEphemeralPub.UntypedBytes())
	s.MixHash(controlEphemeralPub.UntypedBytes())
	if _, err := s.MixDH(controlEphemeral, machineEphemeralPub); err != nil {
		return nil, fmt.Errorf("computing ee: %w", err)
	}
	cipher, err = s.MixDH(controlEphemeral, machineKey)
	if err != nil {
		return nil, fmt.Errorf("computing se: %w", err)
	}
	s.EncryptAndHash(cipher, resp.Tag(), nil) // empty message payload

	c1, c2, err := s.Split()
	if err != nil {
		return nil, fmt.Errorf("finalizing handshake: %w", err)
	}

	if _, err := conn.Write(resp[:]); err != nil {
		return nil, err
	}

	// The server transmits with the second cipher from Split and
	// receives with the first, mirroring continueClientHandshake.
	c := &Conn{
		conn:          conn,
		version:       clientVersion,
		peer:          machineKey,
		handshakeHash: s.h,
		tx: txState{
			cipher: c2,
		},
		rx: rxState{
			cipher: c1,
		},
	}
	return c, nil
}

// symmetricState contains the state of an in-flight handshake.
//
// Call Initialize before any other method. Every method except Split
// panics (via checkFinished) once finished has been set.
type symmetricState struct {
	// finished is set by Split, and by continueClientHandshake when
	// it returns; it marks the state as no longer usable.
	finished bool

	h  [blake2s.Size]byte // hash of currently-processed handshake state
	ck [blake2s.Size]byte // chaining key used to construct session keys at the end of the handshake
}

// checkFinished panics if s has already been marked finished, and
// does nothing otherwise. Methods on s call it first so that a
// completed handshake state cannot be used again.
func (s *symmetricState) checkFinished() {
	if !s.finished {
		return
	}
	panic("attempted to use symmetricState after Split was called")
}

// Initialize puts s into the starting handshake state, before any
// handshake message has been processed: both the handshake hash and
// the chaining key are set to the BLAKE2s-256 digest of protocolName.
// It panics if s is already finished.
func (s *symmetricState) Initialize() {
	s.checkFinished()
	seed := blake2s.Sum256([]byte(protocolName))
	s.h, s.ck = seed, seed
}

// MixHash replaces s.h with BLAKE2s(s.h || data), where || denotes
// concatenation. It panics if s is already finished.
func (s *symmetricState) MixHash(data []byte) {
	s.checkFinished()
	hasher := newBLAKE2s()
	hasher.Write(s.h[:])
	hasher.Write(data)

	// Compute the digest into a scratch array, then install it.
	var next [blake2s.Size]byte
	hasher.Sum(next[:0])
	s.h = next
}

// MixDH computes X25519(priv, pub), folds the result into s.ck, and
// returns a singleUseCHP keyed for encrypting or decrypting one piece
// of handshake data.
//
// In spec terms this is MixKey(X25519(...)). Fusing the two steps into
// one function lets the arguments be strongly typed, so a caller
// cannot accidentally run X25519 over two private keys or two public
// keys and silently compute the wrong value.
//
// It panics if s is already finished, and returns an error if the
// X25519 computation or the key derivation fails.
func (s *symmetricState) MixDH(priv key.MachinePrivate, pub key.MachinePublic) (*singleUseCHP, error) {
	s.checkFinished()
	shared, err := curve25519.X25519(priv.UntypedBytes(), pub.UntypedBytes())
	if err != nil {
		return nil, fmt.Errorf("computing X25519: %w", err)
	}

	// HKDF over the shared secret, salted with the current chaining
	// key. The first chunk of output becomes the new chaining key and
	// the following chunk becomes the one-shot cipher key; the read
	// order matters.
	kdf := hkdf.New(newBLAKE2s, shared, s.ck[:], nil)
	_, err = io.ReadFull(kdf, s.ck[:])
	if err != nil {
		return nil, fmt.Errorf("extracting ck: %w", err)
	}
	var aeadKey [chp.KeySize]byte
	_, err = io.ReadFull(kdf, aeadKey[:])
	if err != nil {
		return nil, fmt.Errorf("extracting k: %w", err)
	}
	return newSingleUseCHP(aeadKey), nil
}

// EncryptAndHash seals plaintext with cipher, using the current s.h
// as additional data, and writes the output into ciphertext. The
// sealed bytes are then mixed into s.h.
//
// ciphertext must be exactly len(plaintext)+chp.Overhead bytes long;
// any other size panics. Nothing is returned: the caller reads the
// result from the ciphertext buffer it passed in. It also panics if s
// is already finished.
func (s *symmetricState) EncryptAndHash(cipher *singleUseCHP, ciphertext, plaintext []byte) {
	s.checkFinished()
	if want := len(plaintext) + chp.Overhead; len(ciphertext) != want {
		panic("ciphertext is wrong size for given plaintext")
	}
	sealed := cipher.Seal(ciphertext[:0], plaintext, s.h[:])
	s.MixHash(sealed)
}

// DecryptAndHash opens ciphertext with cipher, using the current s.h
// as additional data, and writes the result into plaintext. Only when
// decryption succeeds is the ciphertext mixed into s.h.
//
// ciphertext must be exactly len(plaintext)+chp.Overhead bytes long,
// otherwise an error is returned before cipher is used. It panics if
// s is already finished.
func (s *symmetricState) DecryptAndHash(cipher *singleUseCHP, plaintext, ciphertext []byte) error {
	s.checkFinished()
	if len(plaintext)+chp.Overhead != len(ciphertext) {
		return errors.New("plaintext is wrong size for given ciphertext")
	}
	_, err := cipher.Open(plaintext[:0], ciphertext, s.h[:])
	if err == nil {
		s.MixHash(ciphertext)
	}
	return err
}

// Split returns two ChaCha20Poly1305 ciphers with keys derived from
// the current handshake state (the chaining key s.ck). The client
// transmits with c1 and receives with c2; the server does the
// opposite.
//
// Split marks s as finished: no method on s, Split included, can be
// used again afterwards. A second call panics rather than handing out
// the same pair of transport keys twice.
//
// An error is returned if key derivation or cipher construction
// fails; s is finished in that case too.
func (s *symmetricState) Split() (c1, c2 cipher.AEAD, err error) {
	// Enforce the same single-use rule as every other method before
	// flipping the flag ourselves.
	s.checkFinished()
	s.finished = true

	// HKDF over an empty secret, salted with the chaining key. The
	// first chunk of output keys c1, the next keys c2.
	var k1, k2 [chp.KeySize]byte
	r := hkdf.New(newBLAKE2s, nil, s.ck[:], nil)
	if _, err := io.ReadFull(r, k1[:]); err != nil {
		return nil, nil, fmt.Errorf("extracting k1: %w", err)
	}
	if _, err := io.ReadFull(r, k2[:]); err != nil {
		return nil, nil, fmt.Errorf("extracting k2: %w", err)
	}
	c1, err = chp.New(k1[:])
	if err != nil {
		return nil, nil, fmt.Errorf("constructing AEAD c1: %w", err)
	}
	c2, err = chp.New(k2[:])
	if err != nil {
		return nil, nil, fmt.Errorf("constructing AEAD c2: %w", err)
	}
	return c1, c2, nil
}

// newBLAKE2s returns a fresh unkeyed BLAKE2s-256 hash.Hash. It panics
// if the hash cannot be constructed.
func newBLAKE2s() hash.Hash {
	hasher, err := blake2s.New256(nil)
	if err == nil {
		return hasher
	}
	// Should never happen, errors only happen when using BLAKE2s
	// in MAC mode with a key.
	panic(err)
}

// newCHP builds a ChaCha20Poly1305 cipher.AEAD from key. It panics if
// the cipher cannot be constructed.
func newCHP(key [chp.KeySize]byte) cipher.AEAD {
	aead, err := chp.New(key[:])
	if err == nil {
		return aead
	}
	// Can only happen if we passed a key of the wrong length. The
	// function signature prevents that.
	panic(err)
}

// singleUseCHP is an instance of ChaCha20Poly1305 that can be used
// only once, either for encrypting or decrypting, but not both. The
// chosen operation is always executed with an all-zeros
// nonce. Subsequent calls to either Seal or Open panic.
type singleUseCHP struct {
	// c is the underlying AEAD. Seal and Open set it to nil on first
	// use, which is how a later call is detected.
	c cipher.AEAD
}

// newSingleUseCHP wraps a ChaCha20Poly1305 cipher built from key in a
// singleUseCHP that has not yet been used.
func newSingleUseCHP(key [chp.KeySize]byte) *singleUseCHP {
	return &singleUseCHP{c: newCHP(key)}
}

// Seal encrypts and authenticates plaintext together with
// additionalData under an all-zeros nonce, appends the result to dst
// and returns it. It consumes c: any later Seal or Open on c panics.
func (c *singleUseCHP) Seal(dst, plaintext, additionalData []byte) []byte {
	aead := c.c
	if aead == nil {
		panic("Attempted reuse of singleUseAEAD")
	}
	// Drop the cipher before using it so c can never be used twice.
	c.c = nil
	zeroNonce := make([]byte, chp.NonceSize)
	return aead.Seal(dst, zeroNonce, plaintext, additionalData)
}

// Open authenticates and decrypts ciphertext together with
// additionalData under an all-zeros nonce, appends the plaintext to
// dst and returns it, or returns an error if authentication fails. It
// consumes c: any later Seal or Open on c panics, whether or not this
// call succeeded.
func (c *singleUseCHP) Open(dst, ciphertext, additionalData []byte) ([]byte, error) {
	aead := c.c
	if aead == nil {
		panic("Attempted reuse of singleUseAEAD")
	}
	// Drop the cipher before using it so c can never be used twice.
	c.c = nil
	zeroNonce := make([]byte, chp.NonceSize)
	return aead.Open(dst, zeroNonce, ciphertext, additionalData)
}