diff --git a/control/controlbase/conn.go b/control/controlbase/conn.go index 0e28f4d08..aba8d755e 100644 --- a/control/controlbase/conn.go +++ b/control/controlbase/conn.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Package noise implements the base transport of the Tailscale 2021 -// control protocol. +// Package controlbase implements the base transport of the Tailscale +// 2021 control protocol. // // The base transport implements Noise IK, instantiated with // Curve25519, ChaCha20Poly1305 and BLAKE2s. diff --git a/control/controlbase/conn_test.go b/control/controlbase/conn_test.go index a8328bd0b..c0dfa9940 100644 --- a/control/controlbase/conn_test.go +++ b/control/controlbase/conn_test.go @@ -202,7 +202,7 @@ func TestConnStd(t *testing.T) { serverErr := make(chan error, 1) go func() { var err error - c2, err = Server(context.Background(), s2, controlKey) + c2, err = Server(context.Background(), s2, controlKey, nil) serverErr <- err }() c1, err = Client(context.Background(), s1, machineKey, controlKey.Public()) @@ -319,7 +319,7 @@ func pairWithConns(t *testing.T, clientConn, serverConn net.Conn) (*Conn, *Conn) ) go func() { var err error - server, err = Server(context.Background(), serverConn, controlKey) + server, err = Server(context.Background(), serverConn, controlKey, nil) serverErr <- err }() diff --git a/control/controlbase/handshake.go b/control/controlbase/handshake.go index 57606581c..393576ee8 100644 --- a/control/controlbase/handshake.go +++ b/control/controlbase/handshake.go @@ -50,21 +50,23 @@ func protocolVersionPrologue(version uint16) []byte { return strconv.AppendUint(ret, uint64(version), 10) } -// Client initiates a control client handshake, returning the resulting -// control connection. -// -// The context deadline, if any, covers the entire handshaking -// process. Any preexisting Conn deadline is removed. -func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, controlKey key.MachinePublic) (*Conn, error) { - if deadline, ok := ctx.Deadline(); ok { - if err := conn.SetDeadline(deadline); err != nil { - return nil, fmt.Errorf("setting conn deadline: %w", err) - } - defer func() { - conn.SetDeadline(time.Time{}) - }() - } +// HandshakeContinuation upgrades a net.Conn to a Conn. The net.Conn +// is assumed to have already sent the client>server handshake +// initiation message. +type HandshakeContinuation func(context.Context, net.Conn) (*Conn, error) +// ClientDeferred initiates a control client handshake, returning the +// initial message to send to the server and a continuation to +// finalize the handshake. +// +// ClientDeferred is split in this way for RTT reduction: we run this +// protocol after negotiating a protocol switch from HTTP/HTTPS. If we +// completely serialized the negotiation followed by the handshake, +// we'd pay an extra RTT to transmit the handshake initiation after +// protocol switching. By splitting the handshake into an initial +// message and a continuation, we can embed the handshake initiation +// into the HTTP protocol switching request and avoid a bit of delay. +func ClientDeferred(machineKey key.MachinePrivate, controlKey key.MachinePublic) (initialHandshake []byte, continueHandshake HandshakeContinuation, err error) { var s symmetricState s.Initialize() @@ -83,18 +85,53 @@ func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, c s.MixHash(machineEphemeralPub.UntypedBytes()) cipher, err := s.MixDH(machineEphemeral, controlKey) if err != nil { - return nil, fmt.Errorf("computing es: %w", err) + return nil, nil, fmt.Errorf("computing es: %w", err) } machineKeyPub := machineKey.Public() s.EncryptAndHash(cipher, init.MachinePub(), machineKeyPub.UntypedBytes()) cipher, err = s.MixDH(machineKey, controlKey) if err != nil { - return nil, fmt.Errorf("computing ss: %w", err) + return nil, nil, fmt.Errorf("computing ss: %w", err) } s.EncryptAndHash(cipher, init.Tag(), nil) // empty message payload - if _, err := conn.Write(init[:]); err != nil { - return nil, fmt.Errorf("writing initiation: %w", err) + cont := func(ctx context.Context, conn net.Conn) (*Conn, error) { + return continueClientHandshake(ctx, conn, &s, machineKey, machineEphemeral, controlKey) + } + return init[:], cont, nil +} + +// Client wraps ClientDeferred and immediately invokes the returned +// continuation with conn. +// +// This is a helper for when you don't need the fancy +// continuation-style handshake, and just want to synchronously +// upgrade a net.Conn to a secure transport. +func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, controlKey key.MachinePublic) (*Conn, error) { + init, cont, err := ClientDeferred(machineKey, controlKey) + if err != nil { + return nil, err + } + if _, err := conn.Write(init); err != nil { + return nil, err + } + return cont(ctx, conn) +} + +func continueClientHandshake(ctx context.Context, conn net.Conn, s *symmetricState, machineKey, machineEphemeral key.MachinePrivate, controlKey key.MachinePublic) (*Conn, error) { + // No matter what, this function can only run once per s. Ensure + // attempted reuse causes a panic. + defer func() { + s.finished = true + }() + + if deadline, ok := ctx.Deadline(); ok { + if err := conn.SetDeadline(deadline); err != nil { + return nil, fmt.Errorf("setting conn deadline: %w", err) + } + defer func() { + conn.SetDeadline(time.Time{}) + }() } // Read in the payload and look for errors/protocol violations from the server. @@ -122,10 +159,10 @@ func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, c // <- e, ee, se controlEphemeralPub := key.MachinePublicFromRaw32(mem.B(resp.EphemeralPub())) s.MixHash(controlEphemeralPub.UntypedBytes()) - if _, err = s.MixDH(machineEphemeral, controlEphemeralPub); err != nil { + if _, err := s.MixDH(machineEphemeral, controlEphemeralPub); err != nil { return nil, fmt.Errorf("computing ee: %w", err) } - cipher, err = s.MixDH(machineKey, controlEphemeralPub) + cipher, err := s.MixDH(machineKey, controlEphemeralPub) if err != nil { return nil, fmt.Errorf("computing se: %w", err) } @@ -156,9 +193,13 @@ func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, c // Server initiates a control server handshake, returning the resulting // control connection. // +// optionalInit can be the client's initial handshake message as +// returned by ClientDeferred, or nil in which case the initial +// message is read from conn. +// // The context deadline, if any, covers the entire handshaking // process. -func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate) (*Conn, error) { +func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, optionalInit []byte) (*Conn, error) { if deadline, ok := ctx.Deadline(); ok { if err := conn.SetDeadline(deadline); err != nil { return nil, fmt.Errorf("setting conn deadline: %w", err) @@ -190,7 +231,12 @@ func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate) ( s.Initialize() var init initiationMessage - if _, err := io.ReadFull(conn, init.Header()); err != nil { + if optionalInit != nil { + if len(optionalInit) != len(init) { + return nil, sendErr("wrong handshake initiation size") + } + copy(init[:], optionalInit) + } else if _, err := io.ReadFull(conn, init.Header()); err != nil { return nil, err } if init.Version() != protocolVersion { @@ -202,8 +248,11 @@ func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate) ( if init.Length() != len(init.Payload()) { return nil, sendErr("wrong handshake initiation length") } - if _, err := io.ReadFull(conn, init.Payload()); err != nil { - return nil, err + // if optionalInit was provided, we have the payload already. + if optionalInit == nil { + if _, err := io.ReadFull(conn, init.Payload()); err != nil { + return nil, err + } } // prologue. Can only do this once we at least think the client is diff --git a/control/controlbase/handshake_test.go b/control/controlbase/handshake_test.go index a5664c11a..9cdc6f5f2 100644 --- a/control/controlbase/handshake_test.go +++ b/control/controlbase/handshake_test.go @@ -26,7 +26,7 @@ func TestHandshake(t *testing.T) { ) go func() { var err error - server, err = Server(context.Background(), serverConn, serverKey) + server, err = Server(context.Background(), serverConn, serverKey, nil) serverErr <- err }() @@ -78,7 +78,7 @@ func TestNoReuse(t *testing.T) { ) go func() { var err error - server, err = Server(context.Background(), serverConn, serverKey) + server, err = Server(context.Background(), serverConn, serverKey, nil) serverErr <- err }() @@ -172,7 +172,7 @@ func TestTampering(t *testing.T) { serverErr = make(chan error, 1) ) go func() { - _, err := Server(context.Background(), serverConn, serverKey) + _, err := Server(context.Background(), serverConn, serverKey, nil) // If the server failed, we have to close the Conn to // unblock the client. if err != nil { @@ -200,7 +200,7 @@ func TestTampering(t *testing.T) { serverErr = make(chan error, 1) ) go func() { - _, err := Server(context.Background(), serverConn, serverKey) + _, err := Server(context.Background(), serverConn, serverKey, nil) serverErr <- err }() @@ -225,7 +225,7 @@ func TestTampering(t *testing.T) { serverErr = make(chan error, 1) ) go func() { - server, err := Server(context.Background(), serverConn, serverKey) + server, err := Server(context.Background(), serverConn, serverKey, nil) serverErr <- err _, err = io.WriteString(server, strings.Repeat("a", 14)) serverErr <- err @@ -266,7 +266,7 @@ func TestTampering(t *testing.T) { serverErr = make(chan error, 1) ) go func() { - server, err := Server(context.Background(), serverConn, serverKey) + server, err := Server(context.Background(), serverConn, serverKey, nil) serverErr <- err var bs [100]byte // The server needs a timeout if the tampering is hitting the length header. diff --git a/control/controlbase/interop_test.go b/control/controlbase/interop_test.go index 04bd7f41d..3417639fe 100644 --- a/control/controlbase/interop_test.go +++ b/control/controlbase/interop_test.go @@ -29,7 +29,7 @@ func TestInteropClient(t *testing.T) { ) go func() { - server, err := Server(context.Background(), s2, controlKey) + server, err := Server(context.Background(), s2, controlKey, nil) serverErr <- err if err != nil { return