control/controlbase: enable asynchronous client handshaking.

With this change, the client can obtain the initial handshake message
separately from the rest of the handshake, for embedding into another
protocol. This enables things like RTT reduction by stuffing the
handshake initiation message into an HTTP header.

Similarly, the server API optionally accepts a pre-read Noise initiation
message, in addition to reading the message directly off a net.Conn.

Updates #3488

Signed-off-by: David Anderson <danderson@tailscale.com>
This commit is contained in:
David Anderson 2022-01-17 15:30:30 -08:00 committed by Dave Anderson
parent 6cd180746f
commit d5a7eabcd0
5 changed files with 84 additions and 35 deletions

View File

@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package noise implements the base transport of the Tailscale 2021
// control protocol.
// Package controlbase implements the base transport of the Tailscale
// 2021 control protocol.
//
// The base transport implements Noise IK, instantiated with
// Curve25519, ChaCha20Poly1305 and BLAKE2s.

View File

@ -202,7 +202,7 @@ func TestConnStd(t *testing.T) {
serverErr := make(chan error, 1)
go func() {
var err error
c2, err = Server(context.Background(), s2, controlKey)
c2, err = Server(context.Background(), s2, controlKey, nil)
serverErr <- err
}()
c1, err = Client(context.Background(), s1, machineKey, controlKey.Public())
@ -319,7 +319,7 @@ func pairWithConns(t *testing.T, clientConn, serverConn net.Conn) (*Conn, *Conn)
)
go func() {
var err error
server, err = Server(context.Background(), serverConn, controlKey)
server, err = Server(context.Background(), serverConn, controlKey, nil)
serverErr <- err
}()

View File

@ -50,21 +50,23 @@ func protocolVersionPrologue(version uint16) []byte {
return strconv.AppendUint(ret, uint64(version), 10)
}
// Client initiates a control client handshake, returning the resulting
// control connection.
//
// The context deadline, if any, covers the entire handshaking
// process. Any preexisting Conn deadline is removed.
func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, controlKey key.MachinePublic) (*Conn, error) {
if deadline, ok := ctx.Deadline(); ok {
if err := conn.SetDeadline(deadline); err != nil {
return nil, fmt.Errorf("setting conn deadline: %w", err)
}
defer func() {
conn.SetDeadline(time.Time{})
}()
}
// HandshakeContinuation upgrades a net.Conn to a Conn. The net.Conn
// is assumed to have already sent the client>server handshake
// initiation message.
type HandshakeContinuation func(context.Context, net.Conn) (*Conn, error)
// ClientDeferred initiates a control client handshake, returning the
// initial message to send to the server and a continuation to
// finalize the handshake.
//
// ClientDeferred is split in this way for RTT reduction: we run this
// protocol after negotiating a protocol switch from HTTP/HTTPS. If we
// completely serialized the negotiation followed by the handshake,
// we'd pay an extra RTT to transmit the handshake initiation after
// protocol switching. By splitting the handshake into an initial
// message and a continuation, we can embed the handshake initiation
// into the HTTP protocol switching request and avoid a bit of delay.
func ClientDeferred(machineKey key.MachinePrivate, controlKey key.MachinePublic) (initialHandshake []byte, continueHandshake HandshakeContinuation, err error) {
var s symmetricState
s.Initialize()
@ -83,18 +85,53 @@ func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, c
s.MixHash(machineEphemeralPub.UntypedBytes())
cipher, err := s.MixDH(machineEphemeral, controlKey)
if err != nil {
return nil, fmt.Errorf("computing es: %w", err)
return nil, nil, fmt.Errorf("computing es: %w", err)
}
machineKeyPub := machineKey.Public()
s.EncryptAndHash(cipher, init.MachinePub(), machineKeyPub.UntypedBytes())
cipher, err = s.MixDH(machineKey, controlKey)
if err != nil {
return nil, fmt.Errorf("computing ss: %w", err)
return nil, nil, fmt.Errorf("computing ss: %w", err)
}
s.EncryptAndHash(cipher, init.Tag(), nil) // empty message payload
if _, err := conn.Write(init[:]); err != nil {
return nil, fmt.Errorf("writing initiation: %w", err)
cont := func(ctx context.Context, conn net.Conn) (*Conn, error) {
return continueClientHandshake(ctx, conn, &s, machineKey, machineEphemeral, controlKey)
}
return init[:], cont, nil
}
// Client wraps ClientDeferred and immediately invokes the returned
// continuation with conn.
//
// This is a helper for when you don't need the fancy
// continuation-style handshake, and just want to synchronously
// upgrade a net.Conn to a secure transport.
func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, controlKey key.MachinePublic) (*Conn, error) {
init, cont, err := ClientDeferred(machineKey, controlKey)
if err != nil {
return nil, err
}
if _, err := conn.Write(init); err != nil {
return nil, err
}
return cont(ctx, conn)
}
func continueClientHandshake(ctx context.Context, conn net.Conn, s *symmetricState, machineKey, machineEphemeral key.MachinePrivate, controlKey key.MachinePublic) (*Conn, error) {
// No matter what, this function can only run once per s. Ensure
// attempted reuse causes a panic.
defer func() {
s.finished = true
}()
if deadline, ok := ctx.Deadline(); ok {
if err := conn.SetDeadline(deadline); err != nil {
return nil, fmt.Errorf("setting conn deadline: %w", err)
}
defer func() {
conn.SetDeadline(time.Time{})
}()
}
// Read in the payload and look for errors/protocol violations from the server.
@ -122,10 +159,10 @@ func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, c
// <- e, ee, se
controlEphemeralPub := key.MachinePublicFromRaw32(mem.B(resp.EphemeralPub()))
s.MixHash(controlEphemeralPub.UntypedBytes())
if _, err = s.MixDH(machineEphemeral, controlEphemeralPub); err != nil {
if _, err := s.MixDH(machineEphemeral, controlEphemeralPub); err != nil {
return nil, fmt.Errorf("computing ee: %w", err)
}
cipher, err = s.MixDH(machineKey, controlEphemeralPub)
cipher, err := s.MixDH(machineKey, controlEphemeralPub)
if err != nil {
return nil, fmt.Errorf("computing se: %w", err)
}
@ -156,9 +193,13 @@ func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, c
// Server initiates a control server handshake, returning the resulting
// control connection.
//
// optionalInit can be the client's initial handshake message as
// returned by ClientDeferred, or nil in which case the initial
// message is read from conn.
//
// The context deadline, if any, covers the entire handshaking
// process.
func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate) (*Conn, error) {
func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, optionalInit []byte) (*Conn, error) {
if deadline, ok := ctx.Deadline(); ok {
if err := conn.SetDeadline(deadline); err != nil {
return nil, fmt.Errorf("setting conn deadline: %w", err)
@ -190,7 +231,12 @@ func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate) (
s.Initialize()
var init initiationMessage
if _, err := io.ReadFull(conn, init.Header()); err != nil {
if optionalInit != nil {
if len(optionalInit) != len(init) {
return nil, sendErr("wrong handshake initiation size")
}
copy(init[:], optionalInit)
} else if _, err := io.ReadFull(conn, init.Header()); err != nil {
return nil, err
}
if init.Version() != protocolVersion {
@ -202,8 +248,11 @@ func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate) (
if init.Length() != len(init.Payload()) {
return nil, sendErr("wrong handshake initiation length")
}
if _, err := io.ReadFull(conn, init.Payload()); err != nil {
return nil, err
// if optionalInit was provided, we have the payload already.
if optionalInit == nil {
if _, err := io.ReadFull(conn, init.Payload()); err != nil {
return nil, err
}
}
// prologue. Can only do this once we at least think the client is

View File

@ -26,7 +26,7 @@ func TestHandshake(t *testing.T) {
)
go func() {
var err error
server, err = Server(context.Background(), serverConn, serverKey)
server, err = Server(context.Background(), serverConn, serverKey, nil)
serverErr <- err
}()
@ -78,7 +78,7 @@ func TestNoReuse(t *testing.T) {
)
go func() {
var err error
server, err = Server(context.Background(), serverConn, serverKey)
server, err = Server(context.Background(), serverConn, serverKey, nil)
serverErr <- err
}()
@ -172,7 +172,7 @@ func TestTampering(t *testing.T) {
serverErr = make(chan error, 1)
)
go func() {
_, err := Server(context.Background(), serverConn, serverKey)
_, err := Server(context.Background(), serverConn, serverKey, nil)
// If the server failed, we have to close the Conn to
// unblock the client.
if err != nil {
@ -200,7 +200,7 @@ func TestTampering(t *testing.T) {
serverErr = make(chan error, 1)
)
go func() {
_, err := Server(context.Background(), serverConn, serverKey)
_, err := Server(context.Background(), serverConn, serverKey, nil)
serverErr <- err
}()
@ -225,7 +225,7 @@ func TestTampering(t *testing.T) {
serverErr = make(chan error, 1)
)
go func() {
server, err := Server(context.Background(), serverConn, serverKey)
server, err := Server(context.Background(), serverConn, serverKey, nil)
serverErr <- err
_, err = io.WriteString(server, strings.Repeat("a", 14))
serverErr <- err
@ -266,7 +266,7 @@ func TestTampering(t *testing.T) {
serverErr = make(chan error, 1)
)
go func() {
server, err := Server(context.Background(), serverConn, serverKey)
server, err := Server(context.Background(), serverConn, serverKey, nil)
serverErr <- err
var bs [100]byte
// The server needs a timeout if the tampering is hitting the length header.

View File

@ -29,7 +29,7 @@ func TestInteropClient(t *testing.T) {
)
go func() {
server, err := Server(context.Background(), s2, controlKey)
server, err := Server(context.Background(), s2, controlKey, nil)
serverErr <- err
if err != nil {
return