tailscale/control/noise/conn.go

331 lines
9.4 KiB
Go
Raw Normal View History

// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package noise implements the base transport of the Tailscale 2021
// control protocol.
//
// The base transport implements Noise IK, instantiated with
// Curve25519, ChaCha20Poly1305 and BLAKE2s.
package noise
import (
"crypto/cipher"
"encoding/binary"
"fmt"
"net"
"sync"
"time"
"golang.org/x/crypto/blake2s"
chp "golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/poly1305"
"tailscale.com/types/key"
)
// Framing limits. On the wire every Noise frame is a 2-byte big-endian
// length header followed by at most maxCiphertextSize bytes of ciphertext.
const (
	// maxPlaintextSize is the largest payload carried by a single frame;
	// Write splits larger buffers across several frames.
	maxPlaintextSize = 4096
	// maxCiphertextSize is a full-size payload plus its Poly1305 tag.
	maxCiphertextSize = poly1305.TagSize + maxPlaintextSize
	// maxPacketSize is one complete frame: length header plus ciphertext.
	maxPacketSize = 2 + maxCiphertextSize
)
// A Conn is a secured Noise connection. It implements the net.Conn
// interface, with the unusual trait that any write error (including a
// SetWriteDeadline induced i/o timeout) causes all future writes to
// fail.
//
// Reads and writes use independent state (rx and tx), each guarded by
// its own embedded mutex.
type Conn struct {
	// conn is the underlying transport that carries the framed
	// ciphertext; Read, Write and Close operate on it.
	conn net.Conn
	// peer is the remote side's long-term public key, as returned by Peer.
	peer key.Public
	// handshakeHash is the Noise handshake hash, as returned by
	// HandshakeHash.
	handshakeHash [blake2s.Size]byte
	// rx holds the state used by Read.
	rx rxState
	// tx holds the state used by Write.
	tx txState
}
// rxState is all the Conn state that Read uses. It is guarded by its
// embedded Mutex.
type rxState struct {
	sync.Mutex
	// cipher decrypts incoming frames. It is set to nil by Close and
	// after any failed decryption; Read returns net.ErrClosed while it
	// is nil.
	cipher cipher.AEAD
	// nonce is the next decryption nonce. It is only usable while its
	// first 4 bytes are zero; the last 8 bytes are a big-endian counter
	// incremented after every decryption attempt.
	nonce [chp.NonceSize]byte
	// buf holds raw bytes read from the underlying conn. Frames are
	// decrypted in place, so plaintext also points into buf.
	buf [maxPacketSize]byte
	n         int    // number of valid bytes in buf
	next      int    // offset of next undecrypted packet
	plaintext []byte // slice into buf of decrypted bytes not yet returned by Read
}
// txState is all the Conn state that Write uses. It is guarded by its
// embedded Mutex.
type txState struct {
	sync.Mutex
	// cipher encrypts outgoing frames. It is set to nil by Close and by
	// any failed Write.
	cipher cipher.AEAD
	// nonce is the next encryption nonce. It is only usable while its
	// first 4 bytes are zero; the last 8 bytes are a big-endian counter
	// incremented after every frame is encrypted.
	nonce [chp.NonceSize]byte
	// buf is scratch space for one outgoing frame: a 2-byte big-endian
	// length header followed by the ciphertext.
	buf [maxPacketSize]byte
	// err records the first error returned by Write, and is returned by
	// every later Write. A failed write to the underlying conn is
	// recorded wrapped in errPartialWrite.
	err error
}
// HandshakeHash reports the hash produced by the Noise handshake that
// set up this connection. Mixing it into other messages binds them to
// this specific connection, so they cannot be replayed from another one.
func (c *Conn) HandshakeHash() [blake2s.Size]byte {
	hash := c.handshakeHash
	return hash
}
// Peer reports the long-term public key of the remote side of the
// connection.
func (c *Conn) Peer() (peer key.Public) {
	peer = c.peer
	return peer
}
// validNonce reports whether nonce may still be used. A usable nonce has
// an all-zero 4-byte prefix followed by a big-endian 64-bit counter that
// is not invalidNonce, i.e. a counter in 0 through 2^64-2.
func validNonce(nonce []byte) bool {
	if binary.BigEndian.Uint32(nonce[:4]) != 0 {
		return false
	}
	counter := binary.BigEndian.Uint64(nonce[4:])
	return counter != invalidNonce
}
// readNLocked reads from c.conn into c.rx.buf until at least total bytes
// are buffered, then returns c.rx.buf[:c.rx.n], which may hold more than
// total bytes. The returned slice always starts at the front of
// c.rx.buf, so it is only meaningful when c.rx.next is 0.
//
// It fails with errReadTooBig if total exceeds maxPacketSize, or with
// the underlying read error. Bytes read before an error stay buffered
// in c.rx.buf.
func (c *Conn) readNLocked(total int) ([]byte, error) {
	if total > maxPacketSize {
		return nil, errReadTooBig{total}
	}
	for c.rx.n < total {
		got, err := c.conn.Read(c.rx.buf[c.rx.n:])
		c.rx.n += got
		if err != nil {
			return nil, err
		}
	}
	return c.rx.buf[:c.rx.n], nil
}
// decryptLocked authenticates and decrypts ciphertext in place, leaving
// the result in c.rx.plaintext. It must be called with c.rx locked.
//
// It returns errCipherExhausted when no usable nonce remains. It returns
// the AEAD's error when authentication fails. In that case the receive
// cipher is also discarded, because the stream is no longer in sync with
// the peer.
func (c *Conn) decryptLocked(ciphertext []byte) (err error) {
	nonce := c.rx.nonce[:]
	if !validNonce(nonce) {
		return errCipherExhausted{}
	}
	c.rx.plaintext, err = c.rx.cipher.Open(ciphertext[:0], nonce, ciphertext, nil)

	// The nonce advances whether or not authentication succeeded. This
	// cannot wrap around: validNonce above rejected the final counter.
	counter := binary.BigEndian.Uint64(nonce[4:])
	binary.BigEndian.PutUint64(nonce[4:], counter+1)

	if err == nil {
		return nil
	}
	// Nuke the cipher so that no further decryptions are attempted on a
	// desynchronized stream.
	c.rx.cipher = nil
	return err
}
// encryptLocked seals plaintext as one Noise frame in c.tx.buf. The frame
// is a 2-byte big-endian length header (the ciphertext length, including
// the Poly1305 tag) followed by the ciphertext. It returns the whole
// frame, which is built in c.tx.buf and is overwritten by the next call.
// It fails with errCipherExhausted once no usable nonce remains.
//
// Write keeps each plaintext within maxPlaintextSize, so the frame fits
// in c.tx.buf and its length fits in the header.
func (c *Conn) encryptLocked(plaintext []byte) ([]byte, error) {
	nonce := c.tx.nonce[:]
	if !validNonce(nonce) {
		// Every usable nonce has been spent on this cipher state, so the
		// connection can no longer send anything.
		return nil, errCipherExhausted{}
	}
	header := c.tx.buf[:2]
	binary.BigEndian.PutUint16(header, uint16(len(plaintext)+poly1305.TagSize))
	frame := c.tx.cipher.Seal(header, nonce, plaintext, nil)

	// Safe to bump the counter: validNonce above rules out wraparound.
	counter := binary.BigEndian.Uint64(nonce[4:])
	binary.BigEndian.PutUint64(nonce[4:], counter+1)
	return frame, nil
}
// wholeCiphertextLocked looks for one complete Noise frame in c.rx.buf,
// starting at c.rx.next. A frame is a 2-byte big-endian length header
// followed by that many bytes. If a whole frame is buffered, it advances
// c.rx.next past it and returns the frame, header included. Otherwise it
// returns nil and leaves the read state untouched.
func (c *Conn) wholeCiphertextLocked() []byte {
	if c.rx.n-c.rx.next < 2 {
		return nil // not even a complete length header yet
	}
	pending := c.rx.buf[c.rx.next:c.rx.n]
	frameLen := 2 + int(binary.BigEndian.Uint16(pending[:2]))
	if frameLen > len(pending) {
		return nil
	}
	c.rx.next += frameLen
	return pending[:frameLen]
}
// decryptOneLocked decrypts the next Noise frame into c.rx.plaintext. It
// reads from c.conn only when that frame is not already fully buffered.
// c.rx.plaintext is only valid if the returned error is nil.
func (c *Conn) decryptOneLocked() error {
	c.rx.plaintext = nil

	// Fast path: a whole frame is already sitting in the buffer.
	if frame := c.wholeCiphertextLocked(); frame != nil {
		return c.decryptLocked(frame[2:])
	}

	// Slow path. First slide any leftover partial frame to the front of
	// the buffer, so it can grow in place without worrying about
	// wraparound.
	if c.rx.next != 0 {
		c.rx.n = copy(c.rx.buf[:], c.rx.buf[c.rx.next:c.rx.n])
		c.rx.next = 0
	}

	header, err := c.readNLocked(2)
	if err != nil {
		return err
	}
	frameLen := 2 + int(binary.BigEndian.Uint16(header[:2]))
	buffered, err := c.readNLocked(frameLen)
	if err != nil {
		return err
	}
	c.rx.next = frameLen
	return c.decryptLocked(buffered[2:frameLen])
}
// Read implements io.Reader. It hands out decrypted bytes from the
// current Noise frame, decrypting further frames as needed. It returns
// net.ErrClosed once the receive cipher has been discarded, either by
// Close or by an earlier failed decryption.
func (c *Conn) Read(bs []byte) (int, error) {
	c.rx.Lock()
	defer c.rx.Unlock()

	if c.rx.cipher == nil {
		return 0, net.ErrClosed
	}
	// The peer may send zero-byte Noise messages; skip past them until
	// some plaintext is available.
	for len(c.rx.plaintext) == 0 {
		err := c.decryptOneLocked()
		if err != nil {
			return 0, err
		}
	}
	copied := copy(bs, c.rx.plaintext)
	c.rx.plaintext = c.rx.plaintext[copied:]
	return copied, nil
}
// Write implements io.Writer.
//
// bs is split into chunks of at most maxPlaintextSize bytes. Each chunk
// is encrypted and sent as one Noise frame. The returned count is the
// number of bytes from bs whose frames were handed to the underlying
// conn in full, so it never exceeds len(bs).
//
// Every error is fatal: the transmit cipher is discarded and all later
// Writes fail. If writing to the underlying conn fails, that call
// returns the raw error. Later calls return it wrapped in
// errPartialWrite, because a partially sent frame leaves the peer out of
// sync with us.
func (c *Conn) Write(bs []byte) (n int, err error) {
	c.tx.Lock()
	defer c.tx.Unlock()

	if c.tx.err != nil {
		return 0, c.tx.err
	}
	defer func() {
		if err != nil {
			// All write errors are fatal for this conn, so clear the
			// cipher state whenever an error happens.
			c.tx.cipher = nil
		}
		if c.tx.err == nil {
			// Only record err if nothing is recorded yet. This keeps the
			// errPartialWrite stored below: the failing call returns the
			// raw error, and later calls return the wrapped one.
			c.tx.err = err
		}
	}()

	if c.tx.cipher == nil {
		return 0, net.ErrClosed
	}

	var sent int
	for len(bs) > 0 {
		toSend := bs
		if len(toSend) > maxPlaintextSize {
			toSend = bs[:maxPlaintextSize]
		}
		bs = bs[len(toSend):]

		ciphertext, encErr := c.encryptLocked(toSend)
		if encErr != nil {
			// Earlier chunks were already delivered; report them
			// instead of claiming nothing was written.
			return sent, encErr
		}
		if _, writeErr := c.conn.Write(ciphertext); writeErr != nil {
			// The byte count from c.conn.Write counts ciphertext (length
			// header and tag included) and covers at most a partial
			// frame. It is not a count of bytes from bs, so it must not
			// be added to sent: doing so could even exceed len(bs).
			c.tx.err = errPartialWrite{writeErr}
			return sent, writeErr
		}
		sent += len(toSend)
	}
	return sent, nil
}
// Close implements io.Closer. It closes the underlying conn first, which
// unblocks any Read or Write waiting on it and so lets their locks be
// acquired here. It then discards both cipher states, so that later
// Reads and Writes fail. It returns the error from closing the
// underlying conn.
func (c *Conn) Close() error {
	err := c.conn.Close()

	rx, tx := &c.rx, &c.tx
	rx.Lock()
	rx.cipher = nil
	rx.Unlock()

	tx.Lock()
	tx.cipher = nil
	tx.Unlock()

	return err
}
// LocalAddr returns the local address of the underlying conn.
func (c *Conn) LocalAddr() net.Addr {
	return c.conn.LocalAddr()
}

// RemoteAddr returns the remote address of the underlying conn.
func (c *Conn) RemoteAddr() net.Addr {
	return c.conn.RemoteAddr()
}

// SetDeadline sets the read and write deadlines of the underlying conn.
// A write that fails because of the deadline makes all later writes fail.
func (c *Conn) SetDeadline(t time.Time) error {
	return c.conn.SetDeadline(t)
}

// SetReadDeadline sets the read deadline of the underlying conn.
func (c *Conn) SetReadDeadline(t time.Time) error {
	return c.conn.SetReadDeadline(t)
}

// SetWriteDeadline sets the write deadline of the underlying conn. A
// write that fails because of the deadline makes all later writes fail.
func (c *Conn) SetWriteDeadline(t time.Time) error {
	return c.conn.SetWriteDeadline(t)
}
// errCipherExhausted is returned by encryption or decryption once a
// cipher state has no usable nonces left. It is a permanent error.
type errCipherExhausted struct{}

func (errCipherExhausted) Error() string {
	return "cipher exhausted, no more nonces available for current key"
}

// Temporary reports false: retrying cannot produce a fresh nonce.
func (errCipherExhausted) Temporary() bool { return false }

// Timeout reports false: exhaustion is unrelated to deadlines.
func (errCipherExhausted) Timeout() bool { return false }
// errPartialWrite is returned by Write once the transmit cipher state has
// become unusable because an earlier write to the underlying conn
// failed. It wraps that original write error.
type errPartialWrite struct {
	err error // the write error that caused the desynchronization
}

func (e errPartialWrite) Error() string {
	return fmt.Sprintf("cipher state desynchronized due to partial write (%v)", e.err)
}

// Unwrap exposes the original write error to errors.Is and errors.As.
func (e errPartialWrite) Unwrap() error { return e.err }

// Temporary reports false: a desynchronized stream cannot recover.
func (errPartialWrite) Temporary() bool { return false }

// Timeout reports false, even when the wrapped error was a timeout.
func (errPartialWrite) Timeout() bool { return false }
// errReadTooBig is returned when the peer's frame header announces a
// frame larger than maxPacketSize, which we refuse to buffer.
type errReadTooBig struct {
	requested int // number of bytes the oversized read asked for
}

func (e errReadTooBig) Error() string {
	return fmt.Sprintf("requested read of %d bytes exceeds max allowed Noise frame size", e.requested)
}

// Timeout reports false: the size limit is unrelated to deadlines.
func (errReadTooBig) Timeout() bool { return false }

// Temporary reports false. This error only occurs when the peer sends a
// frame so large that we are unwilling ever to decode it, so retrying
// cannot help.
func (errReadTooBig) Temporary() bool { return false }