mirror of
https://github.com/tailscale/tailscale.git
synced 2024-11-26 11:35:35 +00:00
331 lines
9.4 KiB
Go
331 lines
9.4 KiB
Go
|
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
|
||
|
// Use of this source code is governed by a BSD-style
|
||
|
// license that can be found in the LICENSE file.
|
||
|
|
||
|
// Package noise implements the base transport of the Tailscale 2021
|
||
|
// control protocol.
|
||
|
//
|
||
|
// The base transport implements Noise IK, instantiated with
|
||
|
// Curve25519, ChaCha20Poly1305 and BLAKE2s.
|
||
|
package noise
|
||
|
|
||
|
import (
|
||
|
"crypto/cipher"
|
||
|
"encoding/binary"
|
||
|
"fmt"
|
||
|
"net"
|
||
|
"sync"
|
||
|
"time"
|
||
|
|
||
|
"golang.org/x/crypto/blake2s"
|
||
|
chp "golang.org/x/crypto/chacha20poly1305"
|
||
|
"golang.org/x/crypto/poly1305"
|
||
|
"tailscale.com/types/key"
|
||
|
)
|
||
|
|
||
|
const (
|
||
|
maxPlaintextSize = 4096
|
||
|
maxCiphertextSize = maxPlaintextSize + poly1305.TagSize
|
||
|
maxPacketSize = maxCiphertextSize + 2 // ciphertext + length header
|
||
|
)
|
||
|
|
||
|
// A Conn is a secured Noise connection. It implements the net.Conn
// interface, with the unusual trait that any write error (including a
// SetWriteDeadline induced i/o timeout) causes all future writes to
// fail.
//
// Each direction has its own mutex and cipher state (rx for Read, tx
// for Write), so a Read and a Write can proceed concurrently.
type Conn struct {
	// conn is the underlying transport carrying length-prefixed
	// encrypted Noise frames.
	conn net.Conn
	// peer is the peer's long-term public key, as returned by Peer.
	peer key.Public
	// handshakeHash is the Noise handshake hash, as returned by
	// HandshakeHash.
	handshakeHash [blake2s.Size]byte
	// rx is the receive-side state used by Read.
	rx rxState
	// tx is the transmit-side state used by Write.
	tx txState
}
|
||
|
|
||
|
// rxState is all the Conn state that Read uses.
type rxState struct {
	// Mutex serializes Read, and Close's clearing of cipher.
	sync.Mutex
	// cipher decrypts incoming frames. It is nil once the Conn has
	// been closed or a decryption has failed.
	cipher cipher.AEAD
	// nonce is the nonce for the next frame to decrypt: 4 zero bytes
	// followed by a big-endian 64-bit counter.
	nonce [chp.NonceSize]byte
	// buf holds raw bytes read from the underlying conn; it is sized
	// to fit one maximum-size frame including its length header.
	buf       [maxPacketSize]byte
	n         int    // number of valid bytes in buf
	next      int    // offset of next undecrypted packet
	plaintext []byte // slice into buf of decrypted bytes not yet returned by Read
}
|
||
|
|
||
|
// txState is all the Conn state that Write uses.
type txState struct {
	// Mutex serializes Write, and Close's clearing of cipher.
	sync.Mutex
	// cipher encrypts outgoing frames. It is nil once the Conn has
	// been closed or any Write has failed.
	cipher cipher.AEAD
	// nonce is the nonce for the next frame to encrypt: 4 zero bytes
	// followed by a big-endian 64-bit counter.
	nonce [chp.NonceSize]byte
	// buf is scratch space for one encrypted frame, including its
	// 2-byte length header.
	buf [maxPacketSize]byte
	// err records the first error Write encountered (of any kind);
	// every later Write returns it. A failed write to the underlying
	// conn is recorded wrapped in errPartialWrite.
	err error
}
|
||
|
|
||
|
// HandshakeHash returns the Noise handshake hash for the connection,
|
||
|
// which can be used to bind other messages to this connection
|
||
|
// (i.e. to ensure that the message wasn't replayed from a different
|
||
|
// connection).
|
||
|
func (c *Conn) HandshakeHash() [blake2s.Size]byte {
|
||
|
return c.handshakeHash
|
||
|
}
|
||
|
|
||
|
// Peer returns the peer's long-term public key.
|
||
|
func (c *Conn) Peer() key.Public {
|
||
|
return c.peer
|
||
|
}
|
||
|
|
||
|
// validNonce reports whether nonce is in the valid range for use: 0
|
||
|
// through 2^64-2.
|
||
|
func validNonce(nonce []byte) bool {
|
||
|
return binary.BigEndian.Uint32(nonce[:4]) == 0 && binary.BigEndian.Uint64(nonce[4:]) != invalidNonce
|
||
|
}
|
||
|
|
||
|
// readNLocked reads into c.rxBuf until rxBuf contains at least total
|
||
|
// bytes. Returns a slice of the available bytes in rxBuf, or an
|
||
|
// error if fewer than total bytes are available.
|
||
|
func (c *Conn) readNLocked(total int) ([]byte, error) {
|
||
|
if total > maxPacketSize {
|
||
|
return nil, errReadTooBig{total}
|
||
|
}
|
||
|
for {
|
||
|
if total <= c.rx.n {
|
||
|
return c.rx.buf[:c.rx.n], nil
|
||
|
}
|
||
|
|
||
|
n, err := c.conn.Read(c.rx.buf[c.rx.n:])
|
||
|
c.rx.n += n
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// decryptLocked decrypts ciphertext in-place and sets c.rx.plaintext
|
||
|
// to the decrypted bytes. Returns an error if the cipher is exhausted
|
||
|
// (i.e. can no longer be used safely) or decryption fails.
|
||
|
func (c *Conn) decryptLocked(ciphertext []byte) (err error) {
|
||
|
if !validNonce(c.rx.nonce[:]) {
|
||
|
return errCipherExhausted{}
|
||
|
}
|
||
|
|
||
|
c.rx.plaintext, err = c.rx.cipher.Open(ciphertext[:0], c.rx.nonce[:], ciphertext, nil)
|
||
|
|
||
|
// Safe to increment the nonce here, because we checked for nonce
|
||
|
// wraparound above.
|
||
|
binary.BigEndian.PutUint64(c.rx.nonce[4:], 1+binary.BigEndian.Uint64(c.rx.nonce[4:]))
|
||
|
|
||
|
if err != nil {
|
||
|
// Once a decryption has failed, our Conn is no longer
|
||
|
// synchronized with our peer. Nuke the cipher state to be
|
||
|
// safe, so that no further decryptions are attempted.
|
||
|
c.rx.cipher = nil
|
||
|
}
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// encryptLocked encrypts plaintext into c.tx.buf (including the
|
||
|
// 2-byte length header) and returns a slice of the ciphertext, or an
|
||
|
// error if the cipher is exhausted (i.e. can no longer be used safely).
|
||
|
func (c *Conn) encryptLocked(plaintext []byte) ([]byte, error) {
|
||
|
if !validNonce(c.tx.nonce[:]) {
|
||
|
// Received 2^64-1 messages on this cipher state. Connection
|
||
|
// is no longer usable.
|
||
|
return nil, errCipherExhausted{}
|
||
|
}
|
||
|
|
||
|
binary.BigEndian.PutUint16(c.tx.buf[:2], uint16(len(plaintext)+poly1305.TagSize))
|
||
|
ret := c.tx.cipher.Seal(c.tx.buf[:2], c.tx.nonce[:], plaintext, nil)
|
||
|
|
||
|
// Safe to increment the nonce here, because we checked for nonce
|
||
|
// wraparound above.
|
||
|
binary.BigEndian.PutUint64(c.tx.nonce[4:], 1+binary.BigEndian.Uint64(c.tx.nonce[4:]))
|
||
|
|
||
|
return ret, nil
|
||
|
}
|
||
|
|
||
|
// wholeCiphertextLocked returns a slice of one whole Noise frame from
|
||
|
// c.rx.buf, if one whole ciphertext is available, and advances the
|
||
|
// read state to the next Noise frame in the buffer. Returns nil
|
||
|
// without advancing read state if there's not one whole ciphertext in
|
||
|
// c.rx.buf.
|
||
|
func (c *Conn) wholeCiphertextLocked() []byte {
|
||
|
available := c.rx.n - c.rx.next
|
||
|
if available < 2 {
|
||
|
return nil
|
||
|
}
|
||
|
bs := c.rx.buf[c.rx.next:c.rx.n]
|
||
|
totalSize := int(binary.BigEndian.Uint16(bs[:2])) + 2
|
||
|
if len(bs) < totalSize {
|
||
|
return nil
|
||
|
}
|
||
|
c.rx.next += totalSize
|
||
|
return bs[:totalSize]
|
||
|
}
|
||
|
|
||
|
// decryptOneLocked decrypts one Noise frame, reading from c.conn as needed,
|
||
|
// and sets c.rx.plaintext to point to the decrypted
|
||
|
// bytes. c.rx.plaintext is only valid if err == nil.
|
||
|
func (c *Conn) decryptOneLocked() error {
|
||
|
c.rx.plaintext = nil
|
||
|
|
||
|
// Fast path: do we have one whole ciphertext frame buffered
|
||
|
// already?
|
||
|
if bs := c.wholeCiphertextLocked(); bs != nil {
|
||
|
return c.decryptLocked(bs[2:])
|
||
|
}
|
||
|
|
||
|
if c.rx.next != 0 {
|
||
|
// To simplify the read logic, move the remainder of the
|
||
|
// buffered bytes back to the head of the buffer, so we can
|
||
|
// grow it without worrying about wraparound.
|
||
|
copy(c.rx.buf[:], c.rx.buf[c.rx.next:c.rx.n])
|
||
|
c.rx.n -= c.rx.next
|
||
|
c.rx.next = 0
|
||
|
}
|
||
|
|
||
|
bs, err := c.readNLocked(2)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
totalLen := int(binary.BigEndian.Uint16(bs[:2])) + 2
|
||
|
bs, err = c.readNLocked(totalLen)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
c.rx.next = totalLen
|
||
|
bs = bs[2:totalLen]
|
||
|
|
||
|
return c.decryptLocked(bs)
|
||
|
}
|
||
|
|
||
|
// Read implements io.Reader.
|
||
|
func (c *Conn) Read(bs []byte) (int, error) {
|
||
|
c.rx.Lock()
|
||
|
defer c.rx.Unlock()
|
||
|
|
||
|
if c.rx.cipher == nil {
|
||
|
return 0, net.ErrClosed
|
||
|
}
|
||
|
// Loop to handle receiving a zero-byte Noise message. Just skip
|
||
|
// over it and keep decrypting until we find some bytes.
|
||
|
for len(c.rx.plaintext) == 0 {
|
||
|
if err := c.decryptOneLocked(); err != nil {
|
||
|
return 0, err
|
||
|
}
|
||
|
}
|
||
|
n := copy(bs, c.rx.plaintext)
|
||
|
c.rx.plaintext = c.rx.plaintext[n:]
|
||
|
return n, nil
|
||
|
}
|
||
|
|
||
|
// Write implements io.Writer.
|
||
|
func (c *Conn) Write(bs []byte) (n int, err error) {
|
||
|
c.tx.Lock()
|
||
|
defer c.tx.Unlock()
|
||
|
|
||
|
if c.tx.err != nil {
|
||
|
return 0, c.tx.err
|
||
|
}
|
||
|
defer func() {
|
||
|
if err != nil {
|
||
|
// All write errors are fatal for this conn, so clear the
|
||
|
// cipher state whenever an error happens.
|
||
|
c.tx.cipher = nil
|
||
|
}
|
||
|
if c.tx.err == nil {
|
||
|
// Only set c.tx.err if not nil so that we can return one
|
||
|
// error on the first failure, and a different one for
|
||
|
// subsequent calls. See the error handling around Write
|
||
|
// below for why.
|
||
|
c.tx.err = err
|
||
|
}
|
||
|
}()
|
||
|
|
||
|
if c.tx.cipher == nil {
|
||
|
return 0, net.ErrClosed
|
||
|
}
|
||
|
|
||
|
var sent int
|
||
|
for len(bs) > 0 {
|
||
|
toSend := bs
|
||
|
if len(toSend) > maxPlaintextSize {
|
||
|
toSend = bs[:maxPlaintextSize]
|
||
|
}
|
||
|
bs = bs[len(toSend):]
|
||
|
|
||
|
ciphertext, err := c.encryptLocked(toSend)
|
||
|
if err != nil {
|
||
|
return 0, err
|
||
|
}
|
||
|
|
||
|
if n, err := c.conn.Write(ciphertext); err != nil {
|
||
|
sent += n
|
||
|
// Return the raw error on the Write that actually
|
||
|
// failed. For future writes, return that error wrapped in
|
||
|
// a desync error.
|
||
|
c.tx.err = errPartialWrite{err}
|
||
|
return sent, err
|
||
|
}
|
||
|
sent += len(toSend)
|
||
|
}
|
||
|
return sent, nil
|
||
|
}
|
||
|
|
||
|
// Close implements io.Closer.
|
||
|
func (c *Conn) Close() error {
|
||
|
closeErr := c.conn.Close() // unblocks any waiting reads or writes
|
||
|
c.rx.Lock()
|
||
|
c.rx.cipher = nil
|
||
|
c.rx.Unlock()
|
||
|
c.tx.Lock()
|
||
|
c.tx.cipher = nil
|
||
|
c.tx.Unlock()
|
||
|
return closeErr
|
||
|
}
|
||
|
|
||
|
func (c *Conn) LocalAddr() net.Addr { return c.conn.LocalAddr() }
|
||
|
func (c *Conn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() }
|
||
|
func (c *Conn) SetDeadline(t time.Time) error { return c.conn.SetDeadline(t) }
|
||
|
func (c *Conn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) }
|
||
|
func (c *Conn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) }
|
||
|
|
||
|
// errCipherExhausted is the error returned when a cipher has used up
// its nonces and can no longer encrypt or decrypt safely. It is a
// permanent error: both Temporary and Timeout report false.
type errCipherExhausted struct{}

func (errCipherExhausted) Error() string {
	const msg = "cipher exhausted, no more nonces available for current key"
	return msg
}

// Temporary reports false: a cipher never regains nonces.
func (errCipherExhausted) Temporary() bool { return false }

// Timeout reports false: running out of nonces is not a deadline expiry.
func (errCipherExhausted) Timeout() bool { return false }
|
||
|
|
||
|
// errPartialWrite is the error Write returns on calls after a write to
// the underlying conn has failed, leaving the transmit cipher state out
// of sync with the peer. It wraps the original write error.
type errPartialWrite struct {
	err error
}

func (e errPartialWrite) Error() string {
	const format = "cipher state desynchronized due to partial write (%v)"
	return fmt.Sprintf(format, e.err)
}

// Unwrap returns the original write error, for errors.Is and errors.As.
func (e errPartialWrite) Unwrap() error { return e.err }

// Timeout reports false. Even if the original failure was a deadline,
// the desynchronized state is permanent.
func (errPartialWrite) Timeout() bool { return false }

// Temporary reports false: retrying cannot resynchronize the cipher.
func (errPartialWrite) Temporary() bool { return false }
|
||
|
|
||
|
// errReadTooBig is the error returned when the peer announces a Noise
// frame larger than this Conn is willing to buffer.
type errReadTooBig struct {
	requested int // total frame size in bytes, including the 2-byte length header
}

func (e errReadTooBig) Error() string {
	const format = "requested read of %d bytes exceeds max allowed Noise frame size"
	return fmt.Sprintf(format, e.requested)
}

// Timeout reports false: the error is not caused by a deadline.
func (errReadTooBig) Timeout() bool { return false }

// Temporary reports false. This error only occurs when our peer sends
// a frame so large we're unwilling to ever decode it, so retrying
// cannot help.
func (errReadTooBig) Temporary() bool { return false }
|