From d3acd35a906c5840f155c4fec74585017a55536a Mon Sep 17 00:00:00 2001 From: David Anderson Date: Mon, 25 Oct 2021 17:24:32 -0700 Subject: [PATCH] control/noise: make message headers match the specification. Only the initiation message should carry a protocol version, all others are just type+len. Signed-off-by: David Anderson --- control/noise/conn.go | 16 +++--- control/noise/handshake.go | 9 ++-- control/noise/handshake_test.go | 12 ++--- control/noise/interop_test.go | 31 +++++++----- control/noise/messages.go | 88 +++++++++++---------------------- 5 files changed, 65 insertions(+), 91 deletions(-) diff --git a/control/noise/conn.go b/control/noise/conn.go index 334253474..abb74b4cc 100644 --- a/control/noise/conn.go +++ b/control/noise/conn.go @@ -28,7 +28,7 @@ maxMessageSize = 4096 // maxCiphertextSize is the maximum amount of ciphertext bytes // that one protocol frame can carry, after framing. - maxCiphertextSize = maxMessageSize - headerLen + maxCiphertextSize = maxMessageSize - 3 // maxPlaintextSize is the maximum amount of plaintext bytes that // one protocol frame can carry, after encryption and framing. maxPlaintextSize = maxCiphertextSize - chp.Overhead @@ -115,11 +115,8 @@ func (c *Conn) readNLocked(total int) ([]byte, error) { // decryptLocked decrypts msg (which is header+ciphertext) in-place // and sets c.rx.plaintext to the decrypted bytes. func (c *Conn) decryptLocked(msg []byte) (err error) { - if hdrVersion(msg) != c.version { - return fmt.Errorf("received message with unexpected protocol version %d, want %d", hdrVersion(msg), c.version) - } - if hdrType(msg) != msgTypeRecord { - return fmt.Errorf("received message with unexpected type %d, want %d", hdrType(msg), msgTypeRecord) + if msgType := msg[0]; msgType != msgTypeRecord { + return fmt.Errorf("received message with unexpected type %d, want %d", msgType, msgTypeRecord) } // We don't check the length field here, because the caller // already did in order to figure out how big the msg slice should @@ -156,7 +153,8 @@ func (c *Conn) encryptLocked(plaintext []byte) ([]byte, error) { return nil, errCipherExhausted{} } - setHeader(c.tx.buf[:headerLen], protocolVersion, msgTypeRecord, len(plaintext)+chp.Overhead) + c.tx.buf[0] = msgTypeRecord + binary.BigEndian.PutUint16(c.tx.buf[1:headerLen], uint16(len(plaintext)+chp.Overhead)) ret := c.tx.cipher.Seal(c.tx.buf[:headerLen], c.tx.nonce[:], plaintext, nil) // Safe to increment the nonce here, because we checked for nonce @@ -177,7 +175,7 @@ func (c *Conn) wholeMessageLocked() []byte { return nil } bs := c.rx.buf[c.rx.next:c.rx.n] - totalSize := headerLen + hdrLen(bs) + totalSize := headerLen + int(binary.BigEndian.Uint16(bs[1:3])) if len(bs) < totalSize { return nil } @@ -211,7 +209,7 @@ func (c *Conn) decryptOneLocked() error { } // The rest of the header (besides the length field) gets verified // in decryptLocked, not here. - messageLen := headerLen + hdrLen(bs) + messageLen := headerLen + int(binary.BigEndian.Uint16(bs[1:3])) bs, err = c.readNLocked(messageLen) if err != nil { return err diff --git a/control/noise/handshake.go b/control/noise/handshake.go index 5ca02ea53..79419f62b 100644 --- a/control/noise/handshake.go +++ b/control/noise/handshake.go @@ -7,6 +7,7 @@ import ( "context" "crypto/cipher" + "encoding/binary" "errors" "fmt" "hash" @@ -101,9 +102,6 @@ func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, c if _, err := io.ReadFull(conn, resp.Header()); err != nil { return nil, fmt.Errorf("reading response header: %w", err) } - if resp.Version() != protocolVersion { - return nil, fmt.Errorf("unexpected version %d from server, want %d", resp.Version(), protocolVersion) - } if resp.Type() != msgTypeResponse { if resp.Type() != msgTypeError { return nil, fmt.Errorf("unexpected response message type %d", resp.Type()) @@ -177,7 +175,8 @@ func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate) ( msg = msg[:1<<16] } var hdr [headerLen]byte - setHeader(hdr[:], protocolVersion, msgTypeError, len(msg)) + hdr[0] = msgTypeError + binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg))) if _, err := conn.Write(hdr[:]); err != nil { return fmt.Errorf("sending %q error to client: %w", msg, err) } @@ -283,7 +282,7 @@ type symmetricState struct { ck [blake2s.Size]byte // chaining key used to construct session keys at the end of the handshake } -func (s *symmetricState) checkFinished() error { +func (s *symmetricState) checkFinished() { if s.finished { panic("attempted to use symmetricState after Split was called") } diff --git a/control/noise/handshake_test.go b/control/noise/handshake_test.go index 108b60f56..52eb7fa7a 100644 --- a/control/noise/handshake_test.go +++ b/control/noise/handshake_test.go @@ -188,7 +188,7 @@ func TestTampering(t *testing.T) { } // Tamper with every byte of the server response message. - for i := 0; i < 53; i++ { + for i := 0; i < 51; i++ { var ( clientRaw, serverConn = tsnettest.NewConn("noise", 128000) clientConn = &readerConn{clientRaw, &tamperReader{clientRaw, i, 0}} @@ -213,10 +213,10 @@ func TestTampering(t *testing.T) { } // Tamper with every byte of the first server>client transport message. - for i := 0; i < 32; i++ { + for i := 0; i < 30; i++ { var ( clientRaw, serverConn = tsnettest.NewConn("noise", 128000) - clientConn = &readerConn{clientRaw, &tamperReader{clientRaw, 53 + i, 0}} + clientConn = &readerConn{clientRaw, &tamperReader{clientRaw, 51 + i, 0}} serverKey = key.NewMachine() clientKey = key.NewMachine() serverErr = make(chan error, 1) @@ -239,7 +239,7 @@ func TestTampering(t *testing.T) { } // The client needs a timeout if the tampering is hitting the length header. - if i == 3 || i == 4 { + if i == 1 || i == 2 { client.SetReadDeadline(time.Now().Add(10 * time.Millisecond)) } @@ -254,7 +254,7 @@ func TestTampering(t *testing.T) { } // Tamper with every byte of the first client>server transport message. - for i := 0; i < 32; i++ { + for i := 0; i < 30; i++ { var ( clientConn, serverRaw = tsnettest.NewConn("noise", 128000) serverConn = &readerConn{serverRaw, &tamperReader{serverRaw, 101 + i, 0}} @@ -267,7 +267,7 @@ func TestTampering(t *testing.T) { serverErr <- err var bs [100]byte // The server needs a timeout if the tampering is hitting the length header. - if i == 3 || i == 4 { + if i == 1 || i == 2 { server.SetReadDeadline(time.Now().Add(10 * time.Millisecond)) } n, err := server.Read(bs[:]) diff --git a/control/noise/interop_test.go b/control/noise/interop_test.go index 8c6f35342..cb2b8ae3b 100644 --- a/control/noise/interop_test.go +++ b/control/noise/interop_test.go @@ -124,8 +124,10 @@ func noiseExplorerClient(conn net.Conn, controlKey key.MachinePublic, machineKey session := InitSession(true, protocolVersionPrologue(protocolVersion), mk, peerKey) _, msg1 := SendMessage(&session, nil) - var hdr [headerLen]byte - setHeader(hdr[:], protocolVersion, msgTypeInitiation, 96) + var hdr [initiationHeaderLen]byte + binary.BigEndian.PutUint16(hdr[:2], protocolVersion) + hdr[2] = msgTypeInitiation + binary.BigEndian.PutUint16(hdr[3:5], 96) if _, err := conn.Write(hdr[:]); err != nil { return nil, err } @@ -140,15 +142,15 @@ func noiseExplorerClient(conn net.Conn, controlKey key.MachinePublic, machineKey } var buf [1024]byte - if _, err := io.ReadFull(conn, buf[:53]); err != nil { + if _, err := io.ReadFull(conn, buf[:51]); err != nil { return nil, err } // ignore the header for this test, we're only checking the noise // implementation. msg2 := messagebuffer{ - ciphertext: buf[37:53], + ciphertext: buf[35:51], } - copy(msg2.ne[:], buf[5:37]) + copy(msg2.ne[:], buf[3:35]) _, p, valid := RecvMessage(&session, &msg2) if !valid { return nil, errors.New("handshake failed") @@ -158,19 +160,20 @@ func noiseExplorerClient(conn net.Conn, controlKey key.MachinePublic, machineKey } _, msg3 := SendMessage(&session, payload) - setHeader(hdr[:], protocolVersion, msgTypeRecord, len(msg3.ciphertext)) - if _, err := conn.Write(hdr[:]); err != nil { + hdr[0] = msgTypeRecord + binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg3.ciphertext))) + if _, err := conn.Write(hdr[:3]); err != nil { return nil, err } if _, err := conn.Write(msg3.ciphertext); err != nil { return nil, err } - if _, err := io.ReadFull(conn, buf[:5]); err != nil { + if _, err := io.ReadFull(conn, buf[:3]); err != nil { return nil, err } // Ignore all of the header except the payload length - plen := int(binary.LittleEndian.Uint16(buf[3:5])) + plen := int(binary.BigEndian.Uint16(buf[1:3])) if _, err := io.ReadFull(conn, buf[:plen]); err != nil { return nil, err } @@ -212,7 +215,8 @@ func noiseExplorerServer(conn net.Conn, controlKey key.MachinePrivate, wantMachi _, msg2 := SendMessage(&session, nil) var hdr [headerLen]byte - setHeader(hdr[:], protocolVersion, msgTypeResponse, 48) + hdr[0] = msgTypeResponse + binary.BigEndian.PutUint16(hdr[1:3], 48) if _, err := conn.Write(hdr[:]); err != nil { return nil, err } @@ -223,10 +227,10 @@ func noiseExplorerServer(conn net.Conn, controlKey key.MachinePrivate, wantMachi return nil, err } - if _, err := io.ReadFull(conn, buf[:5]); err != nil { + if _, err := io.ReadFull(conn, buf[:3]); err != nil { return nil, err } - plen := int(binary.LittleEndian.Uint16(buf[3:5])) + plen := int(binary.BigEndian.Uint16(buf[1:3])) if _, err := io.ReadFull(conn, buf[:plen]); err != nil { return nil, err } @@ -240,7 +244,8 @@ func noiseExplorerServer(conn net.Conn, controlKey key.MachinePrivate, wantMachi } _, msg4 := SendMessage(&session, payload) - setHeader(hdr[:], protocolVersion, msgTypeRecord, len(msg4.ciphertext)) + hdr[0] = msgTypeRecord + binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg4.ciphertext))) if _, err := conn.Write(hdr[:]); err != nil { return nil, err } diff --git a/control/noise/messages.go b/control/noise/messages.go index 1d9ddb37d..658cec314 100644 --- a/control/noise/messages.go +++ b/control/noise/messages.go @@ -6,40 +6,6 @@ import "encoding/binary" -// The control protocol wire format is mostly Noise messages -// encapsulated in a small header describing the payload's type and -// length. The one place we deviate from pure Noise+header is that we -// also support sending an unauthenticated plaintext error as payload, -// to provide an explanation for a connection error that happens -// before the handshake completes. -// -// All frames in our protocol have a 5-byte header: -// -// +------+------+------+------+------+ -// | version | type | length | -// +------+------+------+------+------+ -// -// 2b: protocol version -// 1b: message type -// 2b: payload length (not including the header) -// -// Multibyte values are all big-endian on the wire, as is traditional -// for network protocols. -// -// The protocol version is 2 bytes in order to encourage frequent -// revving of the protocol as needed, without fear of running out of -// version numbers. At minimum, the version number must change -// whenever any particulars of the Noise handshake change -// (e.g. switching from Noise IK to Noise IKpsk1 or Noise XX), and -// when security-critical aspects of the "uppper" protocol (the one -// running inside the established base protocol session) change -// (e.g. how further authentication data is bound to the underlying -// session). - -// headerLen is the size of the header that gets prepended to Noise -// messages. -const headerLen = 5 - const ( // msgTypeInitiation frames carry a Noise IK handshake initiation message. msgTypeInitiation = 1 @@ -54,20 +20,19 @@ msgTypeError = 3 // msgTypeRecord frames carry session data bytes. msgTypeRecord = 4 -) -func setHeader(bs []byte, version uint16, msgType byte, length int) { - binary.LittleEndian.PutUint16(bs[:2], uint16(version)) - bs[2] = msgType - binary.LittleEndian.PutUint16(bs[3:5], uint16(length)) -} -func hdrVersion(bs []byte) uint16 { return binary.LittleEndian.Uint16(bs[:2]) } -func hdrType(bs []byte) byte { return bs[2] } -func hdrLen(bs []byte) int { return int(binary.LittleEndian.Uint16(bs[3:5])) } + // headerLen is the size of the header on all messages except msgTypeInitiation. + headerLen = 3 + // initiationHeaderLen is the size of the header on all msgTypeInitiation messages. + initiationHeaderLen = 5 +) // initiationMessage is the protocol message sent from a client // machine to a control server. // +// 2b: protocol version +// 1b: message type (0x01) +// 2b: payload length (96) // 5b: header (see headerLen for fields) // 32b: client ephemeral public key (cleartext) // 48b: client machine public key (encrypted) @@ -76,41 +41,48 @@ func hdrLen(bs []byte) int { return int(binary.LittleEndian.Uint16(bs[3:5 func mkInitiationMessage() initiationMessage { var ret initiationMessage - setHeader(ret[:], protocolVersion, msgTypeInitiation, len(ret.Payload())) + binary.BigEndian.PutUint16(ret[:2], uint16(protocolVersion)) + ret[2] = msgTypeInitiation + binary.BigEndian.PutUint16(ret[3:5], uint16(len(ret.Payload()))) return ret } -func (m *initiationMessage) Header() []byte { return m[:headerLen] } -func (m *initiationMessage) Payload() []byte { return m[headerLen:] } +func (m *initiationMessage) Header() []byte { return m[:initiationHeaderLen] } +func (m *initiationMessage) Payload() []byte { return m[initiationHeaderLen:] } -func (m *initiationMessage) Version() uint16 { return hdrVersion(m.Header()) } -func (m *initiationMessage) Type() byte { return hdrType(m.Header()) } -func (m *initiationMessage) Length() int { return hdrLen(m.Header()) } +func (m *initiationMessage) Version() uint16 { return binary.BigEndian.Uint16(m[:2]) } +func (m *initiationMessage) Type() byte { return m[2] } +func (m *initiationMessage) Length() int { return int(binary.BigEndian.Uint16(m[3:5])) } -func (m *initiationMessage) EphemeralPub() []byte { return m[headerLen : headerLen+32] } -func (m *initiationMessage) MachinePub() []byte { return m[headerLen+32 : headerLen+32+48] } -func (m *initiationMessage) Tag() []byte { return m[headerLen+32+48:] } +func (m *initiationMessage) EphemeralPub() []byte { + return m[initiationHeaderLen : initiationHeaderLen+32] +} +func (m *initiationMessage) MachinePub() []byte { + return m[initiationHeaderLen+32 : initiationHeaderLen+32+48] +} +func (m *initiationMessage) Tag() []byte { return m[initiationHeaderLen+32+48:] } // responseMessage is the protocol message sent from a control server // to a client machine. // -// 5b: header (see headerLen for fields) +// 1b: message type (0x02) +// 2b: payload length (48) // 32b: control ephemeral public key (cleartext) // 16b: message tag (authenticates the whole message) -type responseMessage [53]byte +type responseMessage [51]byte func mkResponseMessage() responseMessage { var ret responseMessage - setHeader(ret[:], protocolVersion, msgTypeResponse, len(ret.Payload())) + ret[0] = msgTypeResponse + binary.BigEndian.PutUint16(ret[1:], uint16(len(ret.Payload()))) return ret } func (m *responseMessage) Header() []byte { return m[:headerLen] } func (m *responseMessage) Payload() []byte { return m[headerLen:] } -func (m *responseMessage) Version() uint16 { return hdrVersion(m.Header()) } -func (m *responseMessage) Type() byte { return hdrType(m.Header()) } -func (m *responseMessage) Length() int { return hdrLen(m.Header()) } +func (m *responseMessage) Type() byte { return m[0] } +func (m *responseMessage) Length() int { return int(binary.BigEndian.Uint16(m[1:3])) } func (m *responseMessage) EphemeralPub() []byte { return m[headerLen : headerLen+32] } func (m *responseMessage) Tag() []byte { return m[headerLen+32:] }