control/controlbase: don't enforce a max protocol version at handshake time.

Doing so makes development unpleasant, because we have to first break the
client by bumping to a version the control server rejects, then upgrade
the control server to make it accept the new version.

This strict rejection at handshake time is only necessary if we want to
blocklist some vulnerable protocol versions in the future. So, switch
to a default-permissive stance: until we have such a version that we
have to eagerly block early, we'll accept whatever version the client
presents, and leave it to the user of controlbase.Conn to make decisions
based on that version.

Noise still enforces that the client and server *agree* on what protocol
version is being used, and the control server still has the option to
finish the handshake and then hang up with an in-noise error, rather
than abort at the handshake level.

Updates #3488

Signed-off-by: David Anderson <danderson@tailscale.com>
This commit is contained in:
David Anderson 2022-04-07 17:43:59 -07:00 committed by Dave Anderson
parent c6ac29bcc4
commit f570372b4d
6 changed files with 18 additions and 28 deletions

View File

@ -206,7 +206,7 @@ func TestConnStd(t *testing.T) {
serverErr := make(chan error, 1) serverErr := make(chan error, 1)
go func() { go func() {
var err error var err error
c2, err = Server(context.Background(), s2, controlKey, testProtocolVersion, nil) c2, err = Server(context.Background(), s2, controlKey, nil)
serverErr <- err serverErr <- err
}() }()
c1, err = Client(context.Background(), s1, machineKey, controlKey.Public(), testProtocolVersion) c1, err = Client(context.Background(), s1, machineKey, controlKey.Public(), testProtocolVersion)
@ -398,7 +398,7 @@ func pairWithConns(t *testing.T, clientConn, serverConn net.Conn) (*Conn, *Conn)
) )
go func() { go func() {
var err error var err error
server, err = Server(context.Background(), serverConn, controlKey, testProtocolVersion, nil) server, err = Server(context.Background(), serverConn, controlKey, nil)
serverErr <- err serverErr <- err
}() }()

View File

@ -193,19 +193,13 @@ func continueClientHandshake(ctx context.Context, conn net.Conn, s *symmetricSta
// Server initiates a control server handshake, returning the resulting // Server initiates a control server handshake, returning the resulting
// control connection. // control connection.
// //
// maxSupportedVersion is the highest handshake version the server is
// willing to handshake with. The server will handshake with any
// version from 0 to maxSupportedVersion inclusive, the caller should
// inspect conn.Version() to determine what version of the handshake
// was executed.
//
// optionalInit can be the client's initial handshake message as // optionalInit can be the client's initial handshake message as
// returned by ClientDeferred, or nil in which case the initial // returned by ClientDeferred, or nil in which case the initial
// message is read from conn. // message is read from conn.
// //
// The context deadline, if any, covers the entire handshaking // The context deadline, if any, covers the entire handshaking
// process. // process.
func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, maxSupportedVersion uint16, optionalInit []byte) (*Conn, error) { func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, optionalInit []byte) (*Conn, error) {
if deadline, ok := ctx.Deadline(); ok { if deadline, ok := ctx.Deadline(); ok {
if err := conn.SetDeadline(deadline); err != nil { if err := conn.SetDeadline(deadline); err != nil {
return nil, fmt.Errorf("setting conn deadline: %w", err) return nil, fmt.Errorf("setting conn deadline: %w", err)
@ -245,15 +239,11 @@ func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, m
} else if _, err := io.ReadFull(conn, init.Header()); err != nil { } else if _, err := io.ReadFull(conn, init.Header()); err != nil {
return nil, err return nil, err
} }
// Currently, these versions exclusively indicate what the upper // Just a rename to make it more obvious what the value is. In the
// RPC protocol understands, the Noise handshake is exactly the // current implementation we don't need to block any protocol
// same in all versions. If that ever changes, this check will // versions at this layer, it's safe to let the handshake proceed
// need to become more complex to handle different kinds of // and then let the caller make decisions based on the agreed-upon
// handshake. // protocol version.
if init.Version() > maxSupportedVersion {
return nil, sendErr("unsupported handshake version")
}
// Just a rename to make it more obvious what the value is
clientVersion := init.Version() clientVersion := init.Version()
if init.Type() != msgTypeInitiation { if init.Type() != msgTypeInitiation {
return nil, sendErr("unexpected handshake message type") return nil, sendErr("unexpected handshake message type")

View File

@ -26,7 +26,7 @@ func TestHandshake(t *testing.T) {
) )
go func() { go func() {
var err error var err error
server, err = Server(context.Background(), serverConn, serverKey, testProtocolVersion, nil) server, err = Server(context.Background(), serverConn, serverKey, nil)
serverErr <- err serverErr <- err
}() }()
@ -78,7 +78,7 @@ func TestNoReuse(t *testing.T) {
) )
go func() { go func() {
var err error var err error
server, err = Server(context.Background(), serverConn, serverKey, testProtocolVersion, nil) server, err = Server(context.Background(), serverConn, serverKey, nil)
serverErr <- err serverErr <- err
}() }()
@ -172,7 +172,7 @@ func TestTampering(t *testing.T) {
serverErr = make(chan error, 1) serverErr = make(chan error, 1)
) )
go func() { go func() {
_, err := Server(context.Background(), serverConn, serverKey, testProtocolVersion, nil) _, err := Server(context.Background(), serverConn, serverKey, nil)
// If the server failed, we have to close the Conn to // If the server failed, we have to close the Conn to
// unblock the client. // unblock the client.
if err != nil { if err != nil {
@ -200,7 +200,7 @@ func TestTampering(t *testing.T) {
serverErr = make(chan error, 1) serverErr = make(chan error, 1)
) )
go func() { go func() {
_, err := Server(context.Background(), serverConn, serverKey, testProtocolVersion, nil) _, err := Server(context.Background(), serverConn, serverKey, nil)
serverErr <- err serverErr <- err
}() }()
@ -225,7 +225,7 @@ func TestTampering(t *testing.T) {
serverErr = make(chan error, 1) serverErr = make(chan error, 1)
) )
go func() { go func() {
server, err := Server(context.Background(), serverConn, serverKey, testProtocolVersion, nil) server, err := Server(context.Background(), serverConn, serverKey, nil)
serverErr <- err serverErr <- err
_, err = io.WriteString(server, strings.Repeat("a", 14)) _, err = io.WriteString(server, strings.Repeat("a", 14))
serverErr <- err serverErr <- err
@ -266,7 +266,7 @@ func TestTampering(t *testing.T) {
serverErr = make(chan error, 1) serverErr = make(chan error, 1)
) )
go func() { go func() {
server, err := Server(context.Background(), serverConn, serverKey, testProtocolVersion, nil) server, err := Server(context.Background(), serverConn, serverKey, nil)
serverErr <- err serverErr <- err
var bs [100]byte var bs [100]byte
// The server needs a timeout if the tampering is hitting the length header. // The server needs a timeout if the tampering is hitting the length header.

View File

@ -29,7 +29,7 @@ func TestInteropClient(t *testing.T) {
) )
go func() { go func() {
server, err := Server(context.Background(), s2, controlKey, testProtocolVersion, nil) server, err := Server(context.Background(), s2, controlKey, nil)
serverErr <- err serverErr <- err
if err != nil { if err != nil {
return return

View File

@ -107,7 +107,7 @@ func testControlHTTP(t *testing.T, proxy proxy) {
const testProtocolVersion = 1 const testProtocolVersion = 1
sch := make(chan serverResult, 1) sch := make(chan serverResult, 1)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := AcceptHTTP(context.Background(), w, r, server, testProtocolVersion) conn, err := AcceptHTTP(context.Background(), w, r, server)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
} }

View File

@ -21,7 +21,7 @@ import (
// //
// AcceptHTTP always writes an HTTP response to w. The caller must not // AcceptHTTP always writes an HTTP response to w. The caller must not
// attempt their own response after calling AcceptHTTP. // attempt their own response after calling AcceptHTTP.
func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, private key.MachinePrivate, maxSupportedVersion uint16) (*controlbase.Conn, error) { func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, private key.MachinePrivate) (*controlbase.Conn, error) {
next := r.Header.Get("Upgrade") next := r.Header.Get("Upgrade")
if next == "" { if next == "" {
http.Error(w, "missing next protocol", http.StatusBadRequest) http.Error(w, "missing next protocol", http.StatusBadRequest)
@ -63,7 +63,7 @@ func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, pri
} }
conn = netutil.NewDrainBufConn(conn, brw.Reader) conn = netutil.NewDrainBufConn(conn, brw.Reader)
nc, err := controlbase.Server(ctx, conn, private, maxSupportedVersion, init) nc, err := controlbase.Server(ctx, conn, private, init)
if err != nil { if err != nil {
conn.Close() conn.Close()
return nil, fmt.Errorf("noise handshake failed: %w", err) return nil, fmt.Errorf("noise handshake failed: %w", err)