mirror of
https://github.com/tailscale/tailscale.git
synced 2025-02-21 12:28:39 +00:00
control/noise: adjust implementation to match revised spec.
Signed-off-by: David Anderson <danderson@tailscale.com>
This commit is contained in:
parent
89a68a4c22
commit
0b392dbaf7
@ -24,9 +24,9 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
maxPlaintextSize = 4096
|
maxMessageSize = 4096
|
||||||
maxCiphertextSize = maxPlaintextSize + poly1305.TagSize
|
maxCiphertextSize = maxMessageSize - headerLen
|
||||||
maxPacketSize = maxCiphertextSize + 2 // ciphertext + length header
|
maxPlaintextSize = maxCiphertextSize - poly1305.TagSize
|
||||||
)
|
)
|
||||||
|
|
||||||
// A Conn is a secured Noise connection. It implements the net.Conn
|
// A Conn is a secured Noise connection. It implements the net.Conn
|
||||||
@ -35,6 +35,7 @@ const (
|
|||||||
// fail.
|
// fail.
|
||||||
type Conn struct {
|
type Conn struct {
|
||||||
conn net.Conn
|
conn net.Conn
|
||||||
|
version int
|
||||||
peer key.Public
|
peer key.Public
|
||||||
handshakeHash [blake2s.Size]byte
|
handshakeHash [blake2s.Size]byte
|
||||||
rx rxState
|
rx rxState
|
||||||
@ -46,7 +47,7 @@ type rxState struct {
|
|||||||
sync.Mutex
|
sync.Mutex
|
||||||
cipher cipher.AEAD
|
cipher cipher.AEAD
|
||||||
nonce [chp.NonceSize]byte
|
nonce [chp.NonceSize]byte
|
||||||
buf [maxPacketSize]byte
|
buf [maxMessageSize]byte
|
||||||
n int // number of valid bytes in buf
|
n int // number of valid bytes in buf
|
||||||
next int // offset of next undecrypted packet
|
next int // offset of next undecrypted packet
|
||||||
plaintext []byte // slice into buf of decrypted bytes
|
plaintext []byte // slice into buf of decrypted bytes
|
||||||
@ -57,10 +58,14 @@ type txState struct {
|
|||||||
sync.Mutex
|
sync.Mutex
|
||||||
cipher cipher.AEAD
|
cipher cipher.AEAD
|
||||||
nonce [chp.NonceSize]byte
|
nonce [chp.NonceSize]byte
|
||||||
buf [maxPacketSize]byte
|
buf [maxMessageSize]byte
|
||||||
err error // records the first partial write error for all future calls
|
err error // records the first partial write error for all future calls
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Conn) ProtocolVersion() int {
|
||||||
|
return c.version
|
||||||
|
}
|
||||||
|
|
||||||
// HandshakeHash returns the Noise handshake hash for the connection,
|
// HandshakeHash returns the Noise handshake hash for the connection,
|
||||||
// which can be used to bind other messages to this connection
|
// which can be used to bind other messages to this connection
|
||||||
// (i.e. to ensure that the message wasn't replayed from a different
|
// (i.e. to ensure that the message wasn't replayed from a different
|
||||||
@ -84,7 +89,7 @@ func validNonce(nonce []byte) bool {
|
|||||||
// bytes. Returns a slice of the available bytes in rxBuf, or an
|
// bytes. Returns a slice of the available bytes in rxBuf, or an
|
||||||
// error if fewer than total bytes are available.
|
// error if fewer than total bytes are available.
|
||||||
func (c *Conn) readNLocked(total int) ([]byte, error) {
|
func (c *Conn) readNLocked(total int) ([]byte, error) {
|
||||||
if total > maxPacketSize {
|
if total > maxMessageSize {
|
||||||
return nil, errReadTooBig{total}
|
return nil, errReadTooBig{total}
|
||||||
}
|
}
|
||||||
for {
|
for {
|
||||||
@ -100,10 +105,20 @@ func (c *Conn) readNLocked(total int) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// decryptLocked decrypts ciphertext in-place and sets c.rx.plaintext
|
// decryptLocked decrypts message (which is header+ciphertext)
|
||||||
// to the decrypted bytes. Returns an error if the cipher is exhausted
|
// in-place and sets c.rx.plaintext to the decrypted bytes. Returns an
|
||||||
// (i.e. can no longer be used safely) or decryption fails.
|
// error if the cipher is exhausted (i.e. can no longer be used
|
||||||
func (c *Conn) decryptLocked(ciphertext []byte) (err error) {
|
// safely) or decryption fails.
|
||||||
|
func (c *Conn) decryptLocked(msg []byte) (err error) {
|
||||||
|
if hdrVersion(msg) != c.version {
|
||||||
|
return fmt.Errorf("received message with unexpected protocol version %d, want %d", hdrVersion(msg), c.version)
|
||||||
|
}
|
||||||
|
if hdrType(msg) != msgTypeRecord {
|
||||||
|
return fmt.Errorf("received message with unexpected type %d, want %d", hdrType(msg), msgTypeRecord)
|
||||||
|
}
|
||||||
|
// length was already handled in caller to size msg.
|
||||||
|
ciphertext := msg[headerLen:]
|
||||||
|
|
||||||
if !validNonce(c.rx.nonce[:]) {
|
if !validNonce(c.rx.nonce[:]) {
|
||||||
return errCipherExhausted{}
|
return errCipherExhausted{}
|
||||||
}
|
}
|
||||||
@ -124,8 +139,8 @@ func (c *Conn) decryptLocked(ciphertext []byte) (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// encryptLocked encrypts plaintext into c.tx.buf (including the
|
// encryptLocked encrypts plaintext into c.tx.buf (including the
|
||||||
// 2-byte length header) and returns a slice of the ciphertext, or an
|
// packet header) and returns a slice of the ciphertext, or an error
|
||||||
// error if the cipher is exhausted (i.e. can no longer be used safely).
|
// if the cipher is exhausted (i.e. can no longer be used safely).
|
||||||
func (c *Conn) encryptLocked(plaintext []byte) ([]byte, error) {
|
func (c *Conn) encryptLocked(plaintext []byte) ([]byte, error) {
|
||||||
if !validNonce(c.tx.nonce[:]) {
|
if !validNonce(c.tx.nonce[:]) {
|
||||||
// Received 2^64-1 messages on this cipher state. Connection
|
// Received 2^64-1 messages on this cipher state. Connection
|
||||||
@ -133,8 +148,8 @@ func (c *Conn) encryptLocked(plaintext []byte) ([]byte, error) {
|
|||||||
return nil, errCipherExhausted{}
|
return nil, errCipherExhausted{}
|
||||||
}
|
}
|
||||||
|
|
||||||
binary.BigEndian.PutUint16(c.tx.buf[:2], uint16(len(plaintext)+poly1305.TagSize))
|
setHeader(c.tx.buf[:5], protocolVersion, msgTypeRecord, len(plaintext)+poly1305.TagSize)
|
||||||
ret := c.tx.cipher.Seal(c.tx.buf[:2], c.tx.nonce[:], plaintext, nil)
|
ret := c.tx.cipher.Seal(c.tx.buf[:5], c.tx.nonce[:], plaintext, nil)
|
||||||
|
|
||||||
// Safe to increment the nonce here, because we checked for nonce
|
// Safe to increment the nonce here, because we checked for nonce
|
||||||
// wraparound above.
|
// wraparound above.
|
||||||
@ -143,18 +158,18 @@ func (c *Conn) encryptLocked(plaintext []byte) ([]byte, error) {
|
|||||||
return ret, nil
|
return ret, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// wholeCiphertextLocked returns a slice of one whole Noise frame from
|
// wholeMessageLocked returns a slice of one whole Noise transport
|
||||||
// c.rx.buf, if one whole ciphertext is available, and advances the
|
// message from c.rx.buf, if one whole message is available, and
|
||||||
// read state to the next Noise frame in the buffer. Returns nil
|
// advances the read state to the next Noise message in the
|
||||||
// without advancing read state if there's not one whole ciphertext in
|
// buffer. Returns nil without advancing read state if there isn't one
|
||||||
// c.rx.buf.
|
// whole message in c.rx.buf.
|
||||||
func (c *Conn) wholeCiphertextLocked() []byte {
|
func (c *Conn) wholeMessageLocked() []byte {
|
||||||
available := c.rx.n - c.rx.next
|
available := c.rx.n - c.rx.next
|
||||||
if available < 2 {
|
if available < headerLen {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
bs := c.rx.buf[c.rx.next:c.rx.n]
|
bs := c.rx.buf[c.rx.next:c.rx.n]
|
||||||
totalSize := int(binary.BigEndian.Uint16(bs[:2])) + 2
|
totalSize := hdrLen(bs) + headerLen
|
||||||
if len(bs) < totalSize {
|
if len(bs) < totalSize {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -162,16 +177,16 @@ func (c *Conn) wholeCiphertextLocked() []byte {
|
|||||||
return bs[:totalSize]
|
return bs[:totalSize]
|
||||||
}
|
}
|
||||||
|
|
||||||
// decryptOneLocked decrypts one Noise frame, reading from c.conn as needed,
|
// decryptOneLocked decrypts one Noise transport message, reading from
|
||||||
// and sets c.rx.plaintext to point to the decrypted
|
// c.conn as needed, and sets c.rx.plaintext to point to the decrypted
|
||||||
// bytes. c.rx.plaintext is only valid if err == nil.
|
// bytes. c.rx.plaintext is only valid if err == nil.
|
||||||
func (c *Conn) decryptOneLocked() error {
|
func (c *Conn) decryptOneLocked() error {
|
||||||
c.rx.plaintext = nil
|
c.rx.plaintext = nil
|
||||||
|
|
||||||
// Fast path: do we have one whole ciphertext frame buffered
|
// Fast path: do we have one whole ciphertext frame buffered
|
||||||
// already?
|
// already?
|
||||||
if bs := c.wholeCiphertextLocked(); bs != nil {
|
if bs := c.wholeMessageLocked(); bs != nil {
|
||||||
return c.decryptLocked(bs[2:])
|
return c.decryptLocked(bs)
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.rx.next != 0 {
|
if c.rx.next != 0 {
|
||||||
@ -183,18 +198,20 @@ func (c *Conn) decryptOneLocked() error {
|
|||||||
c.rx.next = 0
|
c.rx.next = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
bs, err := c.readNLocked(2)
|
bs, err := c.readNLocked(headerLen)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
totalLen := int(binary.BigEndian.Uint16(bs[:2])) + 2
|
// The rest of the header (besides the length field) gets verified
|
||||||
bs, err = c.readNLocked(totalLen)
|
// in decryptLocked, not here.
|
||||||
|
messageLen := headerLen + hdrLen(bs)
|
||||||
|
bs, err = c.readNLocked(messageLen)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
bs = bs[:messageLen]
|
||||||
|
|
||||||
c.rx.next = totalLen
|
c.rx.next = len(bs)
|
||||||
bs = bs[2:totalLen]
|
|
||||||
|
|
||||||
return c.decryptLocked(bs)
|
return c.decryptLocked(bs)
|
||||||
}
|
}
|
||||||
|
@ -12,6 +12,7 @@ import (
|
|||||||
"hash"
|
"hash"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/crypto/blake2s"
|
"golang.org/x/crypto/blake2s"
|
||||||
@ -23,15 +24,32 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
// protocolName is the name of the specific instantiation of the
|
||||||
|
// Noise protocol we're using. Each field is defined in the Noise
|
||||||
|
// spec, and shouldn't be changed unless we're switching to a
|
||||||
|
// different Noise protocol instance.
|
||||||
protocolName = "Noise_IK_25519_ChaChaPoly_BLAKE2s"
|
protocolName = "Noise_IK_25519_ChaChaPoly_BLAKE2s"
|
||||||
// protocolVersion is the version string that gets included as the
|
// protocolVersion is the version of the Tailscale base
|
||||||
// Noise "prologue" in the handshake. It exists so that we can
|
// protocol that Client will use when initiating a handshake.
|
||||||
// ensure that peer have agreed on the protocol version they're
|
protocolVersion = 1
|
||||||
// executing, to defeat some MITM protocol downgrade attacks.
|
// protocolVersionPrefix is the name portion of the protocol
|
||||||
protocolVersion = "Tailscale Control Protocol v1"
|
// name+version string that gets mixed into the Noise handshake as
|
||||||
|
// a prologue.
|
||||||
|
//
|
||||||
|
// This mixing verifies that both clients agree that
|
||||||
|
// they're executing the Tailscale control protocol at a specific
|
||||||
|
// version that matches the advertised version in the cleartext
|
||||||
|
// packet header.
|
||||||
|
protocolVersionPrefix = "Tailscale Control Protocol v"
|
||||||
invalidNonce = ^uint64(0)
|
invalidNonce = ^uint64(0)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func protocolVersionPrologue(version int) []byte {
|
||||||
|
ret := make([]byte, 0, len(protocolVersionPrefix)+5) // 5 bytes is enough to encode all possible version numbers.
|
||||||
|
ret = append(ret, protocolVersionPrefix...)
|
||||||
|
return strconv.AppendUint(ret, uint64(version), 10)
|
||||||
|
}
|
||||||
|
|
||||||
// Client initiates a Noise client handshake, returning the resulting
|
// Client initiates a Noise client handshake, returning the resulting
|
||||||
// Noise connection.
|
// Noise connection.
|
||||||
//
|
//
|
||||||
@ -50,15 +68,18 @@ func Client(ctx context.Context, conn net.Conn, machineKey key.Private, controlK
|
|||||||
var s symmetricState
|
var s symmetricState
|
||||||
s.Initialize()
|
s.Initialize()
|
||||||
|
|
||||||
|
// prologue
|
||||||
|
s.MixHash(protocolVersionPrologue(protocolVersion))
|
||||||
|
|
||||||
// <- s
|
// <- s
|
||||||
// ...
|
// ...
|
||||||
s.MixHash(controlKey[:])
|
s.MixHash(controlKey[:])
|
||||||
|
|
||||||
// -> e, es, s, ss
|
// -> e, es, s, ss
|
||||||
var init initiationMessage
|
init := mkInitiationMessage()
|
||||||
machineEphemeral := key.NewPrivate()
|
machineEphemeral := key.NewPrivate()
|
||||||
machineEphemeralPub := machineEphemeral.Public()
|
machineEphemeralPub := machineEphemeral.Public()
|
||||||
copy(init.MachineEphemeralPub(), machineEphemeralPub[:])
|
copy(init.EphemeralPub(), machineEphemeralPub[:])
|
||||||
s.MixHash(machineEphemeralPub[:])
|
s.MixHash(machineEphemeralPub[:])
|
||||||
if err := s.MixDH(machineEphemeral, controlKey); err != nil {
|
if err := s.MixDH(machineEphemeral, controlKey); err != nil {
|
||||||
return nil, fmt.Errorf("computing es: %w", err)
|
return nil, fmt.Errorf("computing es: %w", err)
|
||||||
@ -74,14 +95,34 @@ func Client(ctx context.Context, conn net.Conn, machineKey key.Private, controlK
|
|||||||
return nil, fmt.Errorf("writing initiation: %w", err)
|
return nil, fmt.Errorf("writing initiation: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// <- e, ee, se
|
// Read in the payload and look for errors/protocol violations from the server.
|
||||||
var resp responseMessage
|
var resp responseMessage
|
||||||
if _, err := io.ReadFull(conn, resp[:]); err != nil {
|
if _, err := io.ReadFull(conn, resp.Header()); err != nil {
|
||||||
return nil, fmt.Errorf("reading response: %w", err)
|
return nil, fmt.Errorf("reading response header: %w", err)
|
||||||
|
}
|
||||||
|
if resp.Version() != protocolVersion {
|
||||||
|
return nil, fmt.Errorf("unexpected version %d from server, want %d", resp.Version(), protocolVersion)
|
||||||
|
}
|
||||||
|
if resp.Type() != msgTypeResponse {
|
||||||
|
if resp.Type() != msgTypeError {
|
||||||
|
return nil, fmt.Errorf("unexpected response message type %d", resp.Type())
|
||||||
|
}
|
||||||
|
msg := make([]byte, resp.Length())
|
||||||
|
if _, err := io.ReadFull(conn, msg); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("server error: %s", string(msg))
|
||||||
|
}
|
||||||
|
if resp.Length() != len(resp.Payload()) {
|
||||||
|
return nil, fmt.Errorf("wrong length %d received for handshake response", resp.Length())
|
||||||
|
}
|
||||||
|
if _, err := io.ReadFull(conn, resp.Payload()); err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// <- e, ee, se
|
||||||
var controlEphemeralPub key.Public
|
var controlEphemeralPub key.Public
|
||||||
copy(controlEphemeralPub[:], resp.ControlEphemeralPub())
|
copy(controlEphemeralPub[:], resp.EphemeralPub())
|
||||||
s.MixHash(controlEphemeralPub[:])
|
s.MixHash(controlEphemeralPub[:])
|
||||||
if err := s.MixDH(machineEphemeral, controlEphemeralPub); err != nil {
|
if err := s.MixDH(machineEphemeral, controlEphemeralPub); err != nil {
|
||||||
return nil, fmt.Errorf("computing ee: %w", err)
|
return nil, fmt.Errorf("computing ee: %w", err)
|
||||||
@ -100,6 +141,7 @@ func Client(ctx context.Context, conn net.Conn, machineKey key.Private, controlK
|
|||||||
|
|
||||||
return &Conn{
|
return &Conn{
|
||||||
conn: conn,
|
conn: conn,
|
||||||
|
version: protocolVersion,
|
||||||
peer: controlKey,
|
peer: controlKey,
|
||||||
handshakeHash: s.h,
|
handshakeHash: s.h,
|
||||||
tx: txState{
|
tx: txState{
|
||||||
@ -126,22 +168,55 @@ func Server(ctx context.Context, conn net.Conn, controlKey key.Private) (*Conn,
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Deliberately does not support formatting, so that we don't echo
|
||||||
|
// attacker-controlled input back to them.
|
||||||
|
sendErr := func(msg string) error {
|
||||||
|
if len(msg) >= 1<<16 {
|
||||||
|
msg = msg[:1<<16]
|
||||||
|
}
|
||||||
|
var hdr [headerLen]byte
|
||||||
|
setHeader(hdr[:], protocolVersion, msgTypeError, len(msg))
|
||||||
|
if _, err := conn.Write(hdr[:]); err != nil {
|
||||||
|
return fmt.Errorf("sending %q error to client: %w", msg, err)
|
||||||
|
}
|
||||||
|
if _, err := conn.Write([]byte(msg)); err != nil {
|
||||||
|
return fmt.Errorf("sending %q error to client: %w", msg, err)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("refused client handshake: %s", msg)
|
||||||
|
}
|
||||||
|
|
||||||
var s symmetricState
|
var s symmetricState
|
||||||
s.Initialize()
|
s.Initialize()
|
||||||
|
|
||||||
|
var init initiationMessage
|
||||||
|
if _, err := io.ReadFull(conn, init.Header()); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if init.Version() != protocolVersion {
|
||||||
|
return nil, sendErr("unsupported protocol version")
|
||||||
|
}
|
||||||
|
if init.Type() != msgTypeInitiation {
|
||||||
|
return nil, sendErr("unexpected handshake message type")
|
||||||
|
}
|
||||||
|
if init.Length() != len(init.Payload()) {
|
||||||
|
return nil, sendErr("wrong handshake initiation length")
|
||||||
|
}
|
||||||
|
if _, err := io.ReadFull(conn, init.Payload()); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// prologue. Can only do this once we at least think the client is
|
||||||
|
// handshaking using a supported version.
|
||||||
|
s.MixHash(protocolVersionPrologue(protocolVersion))
|
||||||
|
|
||||||
// <- s
|
// <- s
|
||||||
// ...
|
// ...
|
||||||
controlKeyPub := controlKey.Public()
|
controlKeyPub := controlKey.Public()
|
||||||
s.MixHash(controlKeyPub[:])
|
s.MixHash(controlKeyPub[:])
|
||||||
|
|
||||||
// -> e, es, s, ss
|
// -> e, es, s, ss
|
||||||
var init initiationMessage
|
|
||||||
if _, err := io.ReadFull(conn, init[:]); err != nil {
|
|
||||||
return nil, fmt.Errorf("reading initiation: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var machineEphemeralPub key.Public
|
var machineEphemeralPub key.Public
|
||||||
copy(machineEphemeralPub[:], init.MachineEphemeralPub())
|
copy(machineEphemeralPub[:], init.EphemeralPub())
|
||||||
s.MixHash(machineEphemeralPub[:])
|
s.MixHash(machineEphemeralPub[:])
|
||||||
if err := s.MixDH(controlKey, machineEphemeralPub); err != nil {
|
if err := s.MixDH(controlKey, machineEphemeralPub); err != nil {
|
||||||
return nil, fmt.Errorf("computing es: %w", err)
|
return nil, fmt.Errorf("computing es: %w", err)
|
||||||
@ -158,10 +233,10 @@ func Server(ctx context.Context, conn net.Conn, controlKey key.Private) (*Conn,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// <- e, ee, se
|
// <- e, ee, se
|
||||||
var resp responseMessage
|
resp := mkResponseMessage()
|
||||||
controlEphemeral := key.NewPrivate()
|
controlEphemeral := key.NewPrivate()
|
||||||
controlEphemeralPub := controlEphemeral.Public()
|
controlEphemeralPub := controlEphemeral.Public()
|
||||||
copy(resp.ControlEphemeralPub(), controlEphemeralPub[:])
|
copy(resp.EphemeralPub(), controlEphemeralPub[:])
|
||||||
s.MixHash(controlEphemeralPub[:])
|
s.MixHash(controlEphemeralPub[:])
|
||||||
if err := s.MixDH(controlEphemeral, machineEphemeralPub); err != nil {
|
if err := s.MixDH(controlEphemeral, machineEphemeralPub); err != nil {
|
||||||
return nil, fmt.Errorf("computing ee: %w", err)
|
return nil, fmt.Errorf("computing ee: %w", err)
|
||||||
@ -182,6 +257,7 @@ func Server(ctx context.Context, conn net.Conn, controlKey key.Private) (*Conn,
|
|||||||
|
|
||||||
return &Conn{
|
return &Conn{
|
||||||
conn: conn,
|
conn: conn,
|
||||||
|
version: protocolVersion,
|
||||||
peer: machineKey,
|
peer: machineKey,
|
||||||
handshakeHash: s.h,
|
handshakeHash: s.h,
|
||||||
tx: txState{
|
tx: txState{
|
||||||
@ -193,21 +269,6 @@ func Server(ctx context.Context, conn net.Conn, controlKey key.Private) (*Conn,
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// initiationMessage is the Noise protocol message sent from a client
|
|
||||||
// machine to a control server.
|
|
||||||
type initiationMessage [96]byte
|
|
||||||
|
|
||||||
func (m *initiationMessage) MachineEphemeralPub() []byte { return m[:32] }
|
|
||||||
func (m *initiationMessage) MachinePub() []byte { return m[32:80] }
|
|
||||||
func (m *initiationMessage) Tag() []byte { return m[80:] }
|
|
||||||
|
|
||||||
// responseMessage is the Noise protocol message sent from a control
|
|
||||||
// server to a client machine.
|
|
||||||
type responseMessage [48]byte
|
|
||||||
|
|
||||||
func (m *responseMessage) ControlEphemeralPub() []byte { return m[:32] }
|
|
||||||
func (m *responseMessage) Tag() []byte { return m[32:] }
|
|
||||||
|
|
||||||
// symmetricState is the SymmetricState object from the Noise protocol
|
// symmetricState is the SymmetricState object from the Noise protocol
|
||||||
// spec. It contains all the symmetric cipher state of an in-flight
|
// spec. It contains all the symmetric cipher state of an in-flight
|
||||||
// handshake. Field names match the variable names in the spec.
|
// handshake. Field names match the variable names in the spec.
|
||||||
@ -232,7 +293,6 @@ func (s *symmetricState) Initialize() {
|
|||||||
s.k = [chp.KeySize]byte{}
|
s.k = [chp.KeySize]byte{}
|
||||||
s.n = invalidNonce
|
s.n = invalidNonce
|
||||||
s.mixer = newBLAKE2s()
|
s.mixer = newBLAKE2s()
|
||||||
s.MixHash([]byte(protocolVersion))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MixHash updates s.h to be BLAKE2s(s.h || data), where || is
|
// MixHash updates s.h to be BLAKE2s(s.h || data), where || is
|
||||||
|
@ -42,6 +42,12 @@ func TestHandshake(t *testing.T) {
|
|||||||
t.Fatal("client and server disagree on handshake hash")
|
t.Fatal("client and server disagree on handshake hash")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if client.ProtocolVersion() != protocolVersion {
|
||||||
|
t.Fatalf("client reporting wrong protocol version %d, want %d", client.ProtocolVersion(), protocolVersion)
|
||||||
|
}
|
||||||
|
if client.ProtocolVersion() != server.ProtocolVersion() {
|
||||||
|
t.Fatalf("peers disagree on protocol version, client=%d server=%d", client.ProtocolVersion(), server.ProtocolVersion())
|
||||||
|
}
|
||||||
if client.Peer() != serverKey.Public() {
|
if client.Peer() != serverKey.Public() {
|
||||||
t.Fatal("client peer key isn't serverKey")
|
t.Fatal("client peer key isn't serverKey")
|
||||||
}
|
}
|
||||||
@ -154,7 +160,7 @@ func (r *tamperReader) Read(bs []byte) (int, error) {
|
|||||||
|
|
||||||
func TestTampering(t *testing.T) {
|
func TestTampering(t *testing.T) {
|
||||||
// Tamper with every byte of the client initiation message.
|
// Tamper with every byte of the client initiation message.
|
||||||
for i := 0; i < 96; i++ {
|
for i := 0; i < 101; i++ {
|
||||||
var (
|
var (
|
||||||
clientConn, serverRaw = tsnettest.NewConn("noise", 128000)
|
clientConn, serverRaw = tsnettest.NewConn("noise", 128000)
|
||||||
serverConn = &readerConn{serverRaw, &tamperReader{serverRaw, i, 0}}
|
serverConn = &readerConn{serverRaw, &tamperReader{serverRaw, i, 0}}
|
||||||
@ -182,7 +188,7 @@ func TestTampering(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Tamper with every byte of the server response message.
|
// Tamper with every byte of the server response message.
|
||||||
for i := 0; i < 48; i++ {
|
for i := 0; i < 53; i++ {
|
||||||
var (
|
var (
|
||||||
clientRaw, serverConn = tsnettest.NewConn("noise", 128000)
|
clientRaw, serverConn = tsnettest.NewConn("noise", 128000)
|
||||||
clientConn = &readerConn{clientRaw, &tamperReader{clientRaw, i, 0}}
|
clientConn = &readerConn{clientRaw, &tamperReader{clientRaw, i, 0}}
|
||||||
@ -210,7 +216,7 @@ func TestTampering(t *testing.T) {
|
|||||||
for i := 0; i < 32; i++ {
|
for i := 0; i < 32; i++ {
|
||||||
var (
|
var (
|
||||||
clientRaw, serverConn = tsnettest.NewConn("noise", 128000)
|
clientRaw, serverConn = tsnettest.NewConn("noise", 128000)
|
||||||
clientConn = &readerConn{clientRaw, &tamperReader{clientRaw, 48 + i, 0}}
|
clientConn = &readerConn{clientRaw, &tamperReader{clientRaw, 53 + i, 0}}
|
||||||
serverKey = key.NewPrivate()
|
serverKey = key.NewPrivate()
|
||||||
clientKey = key.NewPrivate()
|
clientKey = key.NewPrivate()
|
||||||
serverErr = make(chan error, 1)
|
serverErr = make(chan error, 1)
|
||||||
@ -233,7 +239,7 @@ func TestTampering(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// The client needs a timeout if the tampering is hitting the length header.
|
// The client needs a timeout if the tampering is hitting the length header.
|
||||||
if i == 0 || i == 1 {
|
if i == 3 || i == 4 {
|
||||||
client.SetReadDeadline(time.Now().Add(10 * time.Millisecond))
|
client.SetReadDeadline(time.Now().Add(10 * time.Millisecond))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -251,7 +257,7 @@ func TestTampering(t *testing.T) {
|
|||||||
for i := 0; i < 32; i++ {
|
for i := 0; i < 32; i++ {
|
||||||
var (
|
var (
|
||||||
clientConn, serverRaw = tsnettest.NewConn("noise", 128000)
|
clientConn, serverRaw = tsnettest.NewConn("noise", 128000)
|
||||||
serverConn = &readerConn{serverRaw, &tamperReader{serverRaw, 96 + i, 0}}
|
serverConn = &readerConn{serverRaw, &tamperReader{serverRaw, 101 + i, 0}}
|
||||||
serverKey = key.NewPrivate()
|
serverKey = key.NewPrivate()
|
||||||
clientKey = key.NewPrivate()
|
clientKey = key.NewPrivate()
|
||||||
serverErr = make(chan error, 1)
|
serverErr = make(chan error, 1)
|
||||||
@ -261,7 +267,7 @@ func TestTampering(t *testing.T) {
|
|||||||
serverErr <- err
|
serverErr <- err
|
||||||
var bs [100]byte
|
var bs [100]byte
|
||||||
// The server needs a timeout if the tampering is hitting the length header.
|
// The server needs a timeout if the tampering is hitting the length header.
|
||||||
if i == 0 || i == 1 {
|
if i == 3 || i == 4 {
|
||||||
server.SetReadDeadline(time.Now().Add(10 * time.Millisecond))
|
server.SetReadDeadline(time.Now().Add(10 * time.Millisecond))
|
||||||
}
|
}
|
||||||
n, err := server.Read(bs[:])
|
n, err := server.Read(bs[:])
|
||||||
|
@ -120,9 +120,14 @@ func noiseExplorerClient(conn net.Conn, controlKey key.Public, machineKey key.Pr
|
|||||||
private_key: machineKey,
|
private_key: machineKey,
|
||||||
public_key: machineKey.Public(),
|
public_key: machineKey.Public(),
|
||||||
}
|
}
|
||||||
session := InitSession(true, []byte(protocolVersion), mk, controlKey)
|
session := InitSession(true, protocolVersionPrologue(protocolVersion), mk, controlKey)
|
||||||
|
|
||||||
_, msg1 := SendMessage(&session, nil)
|
_, msg1 := SendMessage(&session, nil)
|
||||||
|
var hdr [headerLen]byte
|
||||||
|
setHeader(hdr[:], protocolVersion, msgTypeInitiation, 96)
|
||||||
|
if _, err := conn.Write(hdr[:]); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
if _, err := conn.Write(msg1.ne[:]); err != nil {
|
if _, err := conn.Write(msg1.ne[:]); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -134,13 +139,15 @@ func noiseExplorerClient(conn net.Conn, controlKey key.Public, machineKey key.Pr
|
|||||||
}
|
}
|
||||||
|
|
||||||
var buf [1024]byte
|
var buf [1024]byte
|
||||||
if _, err := io.ReadFull(conn, buf[:48]); err != nil {
|
if _, err := io.ReadFull(conn, buf[:53]); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
// ignore the header for this test, we're only checking the noise
|
||||||
|
// implementation.
|
||||||
msg2 := messagebuffer{
|
msg2 := messagebuffer{
|
||||||
ciphertext: buf[32:48],
|
ciphertext: buf[37:53],
|
||||||
}
|
}
|
||||||
copy(msg2.ne[:], buf[:32])
|
copy(msg2.ne[:], buf[5:37])
|
||||||
_, p, valid := RecvMessage(&session, &msg2)
|
_, p, valid := RecvMessage(&session, &msg2)
|
||||||
if !valid {
|
if !valid {
|
||||||
return nil, errors.New("handshake failed")
|
return nil, errors.New("handshake failed")
|
||||||
@ -150,18 +157,19 @@ func noiseExplorerClient(conn net.Conn, controlKey key.Public, machineKey key.Pr
|
|||||||
}
|
}
|
||||||
|
|
||||||
_, msg3 := SendMessage(&session, payload)
|
_, msg3 := SendMessage(&session, payload)
|
||||||
binary.BigEndian.PutUint16(buf[:2], uint16(len(msg3.ciphertext)))
|
setHeader(hdr[:], protocolVersion, msgTypeRecord, len(msg3.ciphertext))
|
||||||
if _, err := conn.Write(buf[:2]); err != nil {
|
if _, err := conn.Write(hdr[:]); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if _, err := conn.Write(msg3.ciphertext); err != nil {
|
if _, err := conn.Write(msg3.ciphertext); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := io.ReadFull(conn, buf[:2]); err != nil {
|
if _, err := io.ReadFull(conn, buf[:5]); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
plen := int(binary.BigEndian.Uint16(buf[:2]))
|
// Ignore all of the header except the payload length
|
||||||
|
plen := int(binary.LittleEndian.Uint16(buf[3:5]))
|
||||||
if _, err := io.ReadFull(conn, buf[:plen]); err != nil {
|
if _, err := io.ReadFull(conn, buf[:plen]); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -182,17 +190,18 @@ func noiseExplorerServer(conn net.Conn, controlKey key.Private, wantMachineKey k
|
|||||||
private_key: controlKey,
|
private_key: controlKey,
|
||||||
public_key: controlKey.Public(),
|
public_key: controlKey.Public(),
|
||||||
}
|
}
|
||||||
session := InitSession(false, []byte(protocolVersion), mk, [32]byte{})
|
session := InitSession(false, protocolVersionPrologue(protocolVersion), mk, [32]byte{})
|
||||||
|
|
||||||
var buf [1024]byte
|
var buf [1024]byte
|
||||||
if _, err := io.ReadFull(conn, buf[:96]); err != nil {
|
if _, err := io.ReadFull(conn, buf[:101]); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
// Ignore the header, we're just checking the noise implementation.
|
||||||
msg1 := messagebuffer{
|
msg1 := messagebuffer{
|
||||||
ns: buf[32:80],
|
ns: buf[37:85],
|
||||||
ciphertext: buf[80:96],
|
ciphertext: buf[85:101],
|
||||||
}
|
}
|
||||||
copy(msg1.ne[:], buf[:32])
|
copy(msg1.ne[:], buf[5:37])
|
||||||
_, p, valid := RecvMessage(&session, &msg1)
|
_, p, valid := RecvMessage(&session, &msg1)
|
||||||
if !valid {
|
if !valid {
|
||||||
return nil, errors.New("handshake failed")
|
return nil, errors.New("handshake failed")
|
||||||
@ -202,6 +211,11 @@ func noiseExplorerServer(conn net.Conn, controlKey key.Private, wantMachineKey k
|
|||||||
}
|
}
|
||||||
|
|
||||||
_, msg2 := SendMessage(&session, nil)
|
_, msg2 := SendMessage(&session, nil)
|
||||||
|
var hdr [headerLen]byte
|
||||||
|
setHeader(hdr[:], protocolVersion, msgTypeResponse, 48)
|
||||||
|
if _, err := conn.Write(hdr[:]); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
if _, err := conn.Write(msg2.ne[:]); err != nil {
|
if _, err := conn.Write(msg2.ne[:]); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -209,10 +223,10 @@ func noiseExplorerServer(conn net.Conn, controlKey key.Private, wantMachineKey k
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := io.ReadFull(conn, buf[:2]); err != nil {
|
if _, err := io.ReadFull(conn, buf[:5]); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
plen := int(binary.BigEndian.Uint16(buf[:2]))
|
plen := int(binary.LittleEndian.Uint16(buf[3:5]))
|
||||||
if _, err := io.ReadFull(conn, buf[:plen]); err != nil {
|
if _, err := io.ReadFull(conn, buf[:plen]); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -226,8 +240,8 @@ func noiseExplorerServer(conn net.Conn, controlKey key.Private, wantMachineKey k
|
|||||||
}
|
}
|
||||||
|
|
||||||
_, msg4 := SendMessage(&session, payload)
|
_, msg4 := SendMessage(&session, payload)
|
||||||
binary.BigEndian.PutUint16(buf[:2], uint16(len(msg4.ciphertext)))
|
setHeader(hdr[:], protocolVersion, msgTypeRecord, len(msg4.ciphertext))
|
||||||
if _, err := conn.Write(buf[:2]); err != nil {
|
if _, err := conn.Write(hdr[:]); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if _, err := conn.Write(msg4.ciphertext); err != nil {
|
if _, err := conn.Write(msg4.ciphertext); err != nil {
|
||||||
|
26
control/noise/key.go
Normal file
26
control/noise/key.go
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package noise
|
||||||
|
|
||||||
|
// Note that these types are deliberately separate from the types/key
|
||||||
|
// package. That package defines generic curve25519 keys, without
|
||||||
|
// consideration for how those keys are used. We don't want to
|
||||||
|
// encourage mixing machine keys, node keys, and whatever else we
|
||||||
|
// might use curve25519 for.
|
||||||
|
//
|
||||||
|
// Furthermore, the implementation in types/key does some work that is
|
||||||
|
// unnecessary for machine keys, and results in a harder to follow
|
||||||
|
// implementation. In particular, machine keys do not need to be
|
||||||
|
// clamped per the curve25519 spec because they're only used with the
|
||||||
|
// X25519 operation, and the X25519 operation defines its own clamping
|
||||||
|
// and sanity checking logic. Thus, these keys must be used only with
|
||||||
|
// this Noise protocol implementation, and the easiest way to ensure
|
||||||
|
// that is a different type.
|
||||||
|
|
||||||
|
// PrivateKey is a Tailscale machine private key.
|
||||||
|
type PrivateKey [32]byte
|
||||||
|
|
||||||
|
// PublicKey is a Tailscale machine public key.
|
||||||
|
type PublicKey [32]byte
|
87
control/noise/messages.go
Normal file
87
control/noise/messages.go
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package noise
|
||||||
|
|
||||||
|
import "encoding/binary"
|
||||||
|
|
||||||
|
const (
|
||||||
|
msgTypeInitiation = 1
|
||||||
|
msgTypeResponse = 2
|
||||||
|
msgTypeError = 3
|
||||||
|
msgTypeRecord = 4
|
||||||
|
)
|
||||||
|
|
||||||
|
// headerLen is the size of the cleartext message header that gets
|
||||||
|
// prepended to Noise messages.
|
||||||
|
//
|
||||||
|
// 2b: protocol version
|
||||||
|
// 1b: message type
|
||||||
|
// 2b: payload length (not including this header)
|
||||||
|
const headerLen = 5
|
||||||
|
|
||||||
|
func setHeader(bs []byte, version int, msgType byte, length int) {
|
||||||
|
binary.LittleEndian.PutUint16(bs[:2], uint16(version))
|
||||||
|
bs[2] = msgType
|
||||||
|
binary.LittleEndian.PutUint16(bs[3:5], uint16(length))
|
||||||
|
}
|
||||||
|
func hdrVersion(bs []byte) int { return int(binary.LittleEndian.Uint16(bs[:2])) }
|
||||||
|
func hdrType(bs []byte) byte { return bs[2] }
|
||||||
|
func hdrLen(bs []byte) int { return int(binary.LittleEndian.Uint16(bs[3:5])) }
|
||||||
|
|
||||||
|
// initiationMessage is the Noise protocol message sent from a client
|
||||||
|
// machine to a control server.
|
||||||
|
//
|
||||||
|
// 5b: header (see headerLen for fields)
|
||||||
|
// 32b: client ephemeral public key (cleartext)
|
||||||
|
// 48b: client machine public key (encrypted)
|
||||||
|
// 16b: message tag (authenticates the whole message)
|
||||||
|
type initiationMessage [101]byte
|
||||||
|
|
||||||
|
func mkInitiationMessage() initiationMessage {
|
||||||
|
var ret initiationMessage
|
||||||
|
binary.LittleEndian.PutUint16(ret[:2], protocolVersion)
|
||||||
|
ret[2] = msgTypeInitiation
|
||||||
|
binary.LittleEndian.PutUint16(ret[3:5], 96)
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *initiationMessage) Header() []byte { return m[:5] }
|
||||||
|
func (m *initiationMessage) Payload() []byte { return m[5:] }
|
||||||
|
|
||||||
|
func (m *initiationMessage) Version() int { return hdrVersion(m.Header()) }
|
||||||
|
func (m *initiationMessage) Type() byte { return hdrType(m.Header()) }
|
||||||
|
func (m *initiationMessage) Length() int { return hdrLen(m.Header()) }
|
||||||
|
|
||||||
|
func (m *initiationMessage) EphemeralPub() []byte { return m[5:37] }
|
||||||
|
func (m *initiationMessage) MachinePub() []byte { return m[37:85] }
|
||||||
|
func (m *initiationMessage) Tag() []byte { return m[85:] }
|
||||||
|
|
||||||
|
// responseMessage is the Noise protocol message sent from a control
|
||||||
|
// server to a client machine.
|
||||||
|
//
|
||||||
|
// 2b: little-endian protocol version
|
||||||
|
// 1b: message type
|
||||||
|
// 2b: little-endian size of message (not including this header)
|
||||||
|
// 32b: control ephemeral public key (cleartext)
|
||||||
|
// 16b: message tag (authenticates the whole message)
|
||||||
|
type responseMessage [53]byte
|
||||||
|
|
||||||
|
func mkResponseMessage() responseMessage {
|
||||||
|
var ret responseMessage
|
||||||
|
binary.LittleEndian.PutUint16(ret[:2], protocolVersion)
|
||||||
|
ret[2] = msgTypeResponse
|
||||||
|
binary.LittleEndian.PutUint16(ret[3:5], 48)
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *responseMessage) Header() []byte { return m[:5] }
|
||||||
|
func (m *responseMessage) Payload() []byte { return m[5:] }
|
||||||
|
|
||||||
|
func (m *responseMessage) Version() int { return hdrVersion(m.Header()) }
|
||||||
|
func (m *responseMessage) Type() byte { return hdrType(m.Header()) }
|
||||||
|
func (m *responseMessage) Length() int { return hdrLen(m.Header()) }
|
||||||
|
|
||||||
|
func (m *responseMessage) EphemeralPub() []byte { return m[5:37] }
|
||||||
|
func (m *responseMessage) Tag() []byte { return m[37:] }
|
Loading…
x
Reference in New Issue
Block a user