From f570372b4d8a13e323ba488c9f1a70274a2d4e4f Mon Sep 17 00:00:00 2001 From: David Anderson Date: Thu, 7 Apr 2022 17:43:59 -0700 Subject: [PATCH] control/controlbase: don't enforce a max protocol version at handshake time. Doing so makes development unpleasant, because we have to first break the client by bumping to a version the control server rejects, then upgrade the control server to make it accept the new version. This strict rejection at handshake time is only necessary if we want to blocklist some vulnerable protocol versions in the future. So, switch to a default-permissive stance: until we have such a version that we have to eagerly block early, we'll accept whatever version the client presents, and leave it to the user of controlbase.Conn to make decisions based on that version. Noise still enforces that the client and server *agree* on what protocol version is being used, and the control server still has the option to finish the handshake and then hang up with an in-noise error, rather than abort at the handshake level. Updates #3488 Signed-off-by: David Anderson --- control/controlbase/conn_test.go | 4 ++-- control/controlbase/handshake.go | 22 ++++++---------------- control/controlbase/handshake_test.go | 12 ++++++------ control/controlbase/interop_test.go | 2 +- control/controlhttp/http_test.go | 2 +- control/controlhttp/server.go | 4 ++-- 6 files changed, 18 insertions(+), 28 deletions(-) diff --git a/control/controlbase/conn_test.go b/control/controlbase/conn_test.go index 827d5b9a1..04f3f69b8 100644 --- a/control/controlbase/conn_test.go +++ b/control/controlbase/conn_test.go @@ -206,7 +206,7 @@ func TestConnStd(t *testing.T) { serverErr := make(chan error, 1) go func() { var err error - c2, err = Server(context.Background(), s2, controlKey, testProtocolVersion, nil) + c2, err = Server(context.Background(), s2, controlKey, nil) serverErr <- err }() c1, err = Client(context.Background(), s1, machineKey, controlKey.Public(), testProtocolVersion) @@ -398,7 +398,7 @@ func pairWithConns(t *testing.T, clientConn, serverConn net.Conn) (*Conn, *Conn) ) go func() { var err error - server, err = Server(context.Background(), serverConn, controlKey, testProtocolVersion, nil) + server, err = Server(context.Background(), serverConn, controlKey, nil) serverErr <- err }() diff --git a/control/controlbase/handshake.go b/control/controlbase/handshake.go index 0fb2859b6..b18e08a37 100644 --- a/control/controlbase/handshake.go +++ b/control/controlbase/handshake.go @@ -193,19 +193,13 @@ func continueClientHandshake(ctx context.Context, conn net.Conn, s *symmetricSta // Server initiates a control server handshake, returning the resulting // control connection. // -// maxSupportedVersion is the highest handshake version the server is -// willing to handshake with. The server will handshake with any -// version from 0 to maxSupportedVersion inclusive, the caller should -// inspect conn.Version() to determine what version of the handshake -// was executed. -// // optionalInit can be the client's initial handshake message as // returned by ClientDeferred, or nil in which case the initial // message is read from conn. // // The context deadline, if any, covers the entire handshaking // process. -func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, maxSupportedVersion uint16, optionalInit []byte) (*Conn, error) { +func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, optionalInit []byte) (*Conn, error) { if deadline, ok := ctx.Deadline(); ok { if err := conn.SetDeadline(deadline); err != nil { return nil, fmt.Errorf("setting conn deadline: %w", err) @@ -245,15 +239,11 @@ func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, m } else if _, err := io.ReadFull(conn, init.Header()); err != nil { return nil, err } - // Currently, these versions exclusively indicate what the upper - // RPC protocol understands, the Noise handshake is exactly the - // same in all versions. If that ever changes, this check will - // need to become more complex to handle different kinds of - // handshake. - if init.Version() > maxSupportedVersion { - return nil, sendErr("unsupported handshake version") - } - // Just a rename to make it more obvious what the value is + // Just a rename to make it more obvious what the value is. In the + // current implementation we don't need to block any protocol + // versions at this layer, it's safe to let the handshake proceed + // and then let the caller make decisions based on the agreed-upon + // protocol version. clientVersion := init.Version() if init.Type() != msgTypeInitiation { return nil, sendErr("unexpected handshake message type") diff --git a/control/controlbase/handshake_test.go b/control/controlbase/handshake_test.go index ce28f12e8..755454c1c 100644 --- a/control/controlbase/handshake_test.go +++ b/control/controlbase/handshake_test.go @@ -26,7 +26,7 @@ func TestHandshake(t *testing.T) { ) go func() { var err error - server, err = Server(context.Background(), serverConn, serverKey, testProtocolVersion, nil) + server, err = Server(context.Background(), serverConn, serverKey, nil) serverErr <- err }() @@ -78,7 +78,7 @@ func TestNoReuse(t *testing.T) { ) go func() { var err error - server, err = Server(context.Background(), serverConn, serverKey, testProtocolVersion, nil) + server, err = Server(context.Background(), serverConn, serverKey, nil) serverErr <- err }() @@ -172,7 +172,7 @@ func TestTampering(t *testing.T) { serverErr = make(chan error, 1) ) go func() { - _, err := Server(context.Background(), serverConn, serverKey, testProtocolVersion, nil) + _, err := Server(context.Background(), serverConn, serverKey, nil) // If the server failed, we have to close the Conn to // unblock the client. if err != nil { @@ -200,7 +200,7 @@ func TestTampering(t *testing.T) { serverErr = make(chan error, 1) ) go func() { - _, err := Server(context.Background(), serverConn, serverKey, testProtocolVersion, nil) + _, err := Server(context.Background(), serverConn, serverKey, nil) serverErr <- err }() @@ -225,7 +225,7 @@ func TestTampering(t *testing.T) { serverErr = make(chan error, 1) ) go func() { - server, err := Server(context.Background(), serverConn, serverKey, testProtocolVersion, nil) + server, err := Server(context.Background(), serverConn, serverKey, nil) serverErr <- err _, err = io.WriteString(server, strings.Repeat("a", 14)) serverErr <- err @@ -266,7 +266,7 @@ func TestTampering(t *testing.T) { serverErr = make(chan error, 1) ) go func() { - server, err := Server(context.Background(), serverConn, serverKey, testProtocolVersion, nil) + server, err := Server(context.Background(), serverConn, serverKey, nil) serverErr <- err var bs [100]byte // The server needs a timeout if the tampering is hitting the length header. diff --git a/control/controlbase/interop_test.go b/control/controlbase/interop_test.go index b7e7d15e8..133db8bc5 100644 --- a/control/controlbase/interop_test.go +++ b/control/controlbase/interop_test.go @@ -29,7 +29,7 @@ func TestInteropClient(t *testing.T) { ) go func() { - server, err := Server(context.Background(), s2, controlKey, testProtocolVersion, nil) + server, err := Server(context.Background(), s2, controlKey, nil) serverErr <- err if err != nil { return diff --git a/control/controlhttp/http_test.go b/control/controlhttp/http_test.go index c4b8ddc36..1d2adf124 100644 --- a/control/controlhttp/http_test.go +++ b/control/controlhttp/http_test.go @@ -107,7 +107,7 @@ func testControlHTTP(t *testing.T, proxy proxy) { const testProtocolVersion = 1 sch := make(chan serverResult, 1) handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := AcceptHTTP(context.Background(), w, r, server, testProtocolVersion) + conn, err := AcceptHTTP(context.Background(), w, r, server) if err != nil { log.Print(err) } diff --git a/control/controlhttp/server.go b/control/controlhttp/server.go index 8d7073ffe..0e38da860 100644 --- a/control/controlhttp/server.go +++ b/control/controlhttp/server.go @@ -21,7 +21,7 @@ import ( // // AcceptHTTP always writes an HTTP response to w. The caller must not // attempt their own response after calling AcceptHTTP. -func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, private key.MachinePrivate, maxSupportedVersion uint16) (*controlbase.Conn, error) { +func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, private key.MachinePrivate) (*controlbase.Conn, error) { next := r.Header.Get("Upgrade") if next == "" { http.Error(w, "missing next protocol", http.StatusBadRequest) @@ -63,7 +63,7 @@ func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, pri } conn = netutil.NewDrainBufConn(conn, brw.Reader) - nc, err := controlbase.Server(ctx, conn, private, maxSupportedVersion, init) + nc, err := controlbase.Server(ctx, conn, private, init) if err != nil { conn.Close() return nil, fmt.Errorf("noise handshake failed: %w", err)