mirror of
https://github.com/tailscale/tailscale.git
synced 2024-12-01 14:05:39 +00:00
control/controlbase: make the protocol version number selectable.
This is so that we can plumb our client capability version through the protocol as the Noise version. The capability version increments more frequently than strictly required (the Noise version only needs to change when cryptographically-significant changes are made to the protocol, whereas the capability version also indicates changes in non-cryptographically-significant parts of the protocol), but this gives us a safe pre-auth way to determine if the client supports future protocol features, while still relying on Noise's strong assurance that the client and server have agreed on the same version. Currently, the server executes the same protocol regardless of the version number, and just presents the version to the caller so they can do capability-based things in the upper RPC protocol. In future, we may add a ratchet to disallow obsolete protocols, or vary the Noise handshake behavior based on requested version. Updates #3488 Signed-off-by: David Anderson <danderson@tailscale.com>
This commit is contained in:
parent
be861797b4
commit
02ad987e24
@ -26,6 +26,8 @@
|
|||||||
"tailscale.com/types/key"
|
"tailscale.com/types/key"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const testProtocolVersion = 1
|
||||||
|
|
||||||
func TestMessageSize(t *testing.T) {
|
func TestMessageSize(t *testing.T) {
|
||||||
// This test is a regression guard against someone looking at
|
// This test is a regression guard against someone looking at
|
||||||
// maxCiphertextSize, going "huh, we could be more efficient if it
|
// maxCiphertextSize, going "huh, we could be more efficient if it
|
||||||
@ -204,10 +206,10 @@ func TestConnStd(t *testing.T) {
|
|||||||
serverErr := make(chan error, 1)
|
serverErr := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
var err error
|
var err error
|
||||||
c2, err = Server(context.Background(), s2, controlKey, nil)
|
c2, err = Server(context.Background(), s2, controlKey, testProtocolVersion, nil)
|
||||||
serverErr <- err
|
serverErr <- err
|
||||||
}()
|
}()
|
||||||
c1, err = Client(context.Background(), s1, machineKey, controlKey.Public())
|
c1, err = Client(context.Background(), s1, machineKey, controlKey.Public(), testProtocolVersion)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s1.Close()
|
s1.Close()
|
||||||
s2.Close()
|
s2.Close()
|
||||||
@ -396,11 +398,11 @@ func pairWithConns(t *testing.T, clientConn, serverConn net.Conn) (*Conn, *Conn)
|
|||||||
)
|
)
|
||||||
go func() {
|
go func() {
|
||||||
var err error
|
var err error
|
||||||
server, err = Server(context.Background(), serverConn, controlKey, nil)
|
server, err = Server(context.Background(), serverConn, controlKey, testProtocolVersion, nil)
|
||||||
serverErr <- err
|
serverErr <- err
|
||||||
}()
|
}()
|
||||||
|
|
||||||
client, err := Client(context.Background(), clientConn, machineKey, controlKey.Public())
|
client, err := Client(context.Background(), clientConn, machineKey, controlKey.Public(), testProtocolVersion)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("client connection failed: %v", err)
|
t.Fatalf("client connection failed: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -32,7 +32,7 @@
|
|||||||
protocolName = "Noise_IK_25519_ChaChaPoly_BLAKE2s"
|
protocolName = "Noise_IK_25519_ChaChaPoly_BLAKE2s"
|
||||||
// protocolVersion is the version of the control protocol that
|
// protocolVersion is the version of the control protocol that
|
||||||
// Client will use when initiating a handshake.
|
// Client will use when initiating a handshake.
|
||||||
protocolVersion uint16 = 1
|
//protocolVersion uint16 = 1
|
||||||
// protocolVersionPrefix is the name portion of the protocol
|
// protocolVersionPrefix is the name portion of the protocol
|
||||||
// name+version string that gets mixed into the handshake as a
|
// name+version string that gets mixed into the handshake as a
|
||||||
// prologue.
|
// prologue.
|
||||||
@ -66,7 +66,7 @@ func protocolVersionPrologue(version uint16) []byte {
|
|||||||
// protocol switching. By splitting the handshake into an initial
|
// protocol switching. By splitting the handshake into an initial
|
||||||
// message and a continuation, we can embed the handshake initiation
|
// message and a continuation, we can embed the handshake initiation
|
||||||
// into the HTTP protocol switching request and avoid a bit of delay.
|
// into the HTTP protocol switching request and avoid a bit of delay.
|
||||||
func ClientDeferred(machineKey key.MachinePrivate, controlKey key.MachinePublic) (initialHandshake []byte, continueHandshake HandshakeContinuation, err error) {
|
func ClientDeferred(machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (initialHandshake []byte, continueHandshake HandshakeContinuation, err error) {
|
||||||
var s symmetricState
|
var s symmetricState
|
||||||
s.Initialize()
|
s.Initialize()
|
||||||
|
|
||||||
@ -78,7 +78,7 @@ func ClientDeferred(machineKey key.MachinePrivate, controlKey key.MachinePublic)
|
|||||||
s.MixHash(controlKey.UntypedBytes())
|
s.MixHash(controlKey.UntypedBytes())
|
||||||
|
|
||||||
// -> e, es, s, ss
|
// -> e, es, s, ss
|
||||||
init := mkInitiationMessage()
|
init := mkInitiationMessage(protocolVersion)
|
||||||
machineEphemeral := key.NewMachine()
|
machineEphemeral := key.NewMachine()
|
||||||
machineEphemeralPub := machineEphemeral.Public()
|
machineEphemeralPub := machineEphemeral.Public()
|
||||||
copy(init.EphemeralPub(), machineEphemeralPub.UntypedBytes())
|
copy(init.EphemeralPub(), machineEphemeralPub.UntypedBytes())
|
||||||
@ -96,7 +96,7 @@ func ClientDeferred(machineKey key.MachinePrivate, controlKey key.MachinePublic)
|
|||||||
s.EncryptAndHash(cipher, init.Tag(), nil) // empty message payload
|
s.EncryptAndHash(cipher, init.Tag(), nil) // empty message payload
|
||||||
|
|
||||||
cont := func(ctx context.Context, conn net.Conn) (*Conn, error) {
|
cont := func(ctx context.Context, conn net.Conn) (*Conn, error) {
|
||||||
return continueClientHandshake(ctx, conn, &s, machineKey, machineEphemeral, controlKey)
|
return continueClientHandshake(ctx, conn, &s, machineKey, machineEphemeral, controlKey, protocolVersion)
|
||||||
}
|
}
|
||||||
return init[:], cont, nil
|
return init[:], cont, nil
|
||||||
}
|
}
|
||||||
@ -107,8 +107,8 @@ func ClientDeferred(machineKey key.MachinePrivate, controlKey key.MachinePublic)
|
|||||||
// This is a helper for when you don't need the fancy
|
// This is a helper for when you don't need the fancy
|
||||||
// continuation-style handshake, and just want to synchronously
|
// continuation-style handshake, and just want to synchronously
|
||||||
// upgrade a net.Conn to a secure transport.
|
// upgrade a net.Conn to a secure transport.
|
||||||
func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, controlKey key.MachinePublic) (*Conn, error) {
|
func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (*Conn, error) {
|
||||||
init, cont, err := ClientDeferred(machineKey, controlKey)
|
init, cont, err := ClientDeferred(machineKey, controlKey, protocolVersion)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -118,7 +118,7 @@ func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, c
|
|||||||
return cont(ctx, conn)
|
return cont(ctx, conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func continueClientHandshake(ctx context.Context, conn net.Conn, s *symmetricState, machineKey, machineEphemeral key.MachinePrivate, controlKey key.MachinePublic) (*Conn, error) {
|
func continueClientHandshake(ctx context.Context, conn net.Conn, s *symmetricState, machineKey, machineEphemeral key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (*Conn, error) {
|
||||||
// No matter what, this function can only run once per s. Ensure
|
// No matter what, this function can only run once per s. Ensure
|
||||||
// attempted reuse causes a panic.
|
// attempted reuse causes a panic.
|
||||||
defer func() {
|
defer func() {
|
||||||
@ -193,13 +193,19 @@ func continueClientHandshake(ctx context.Context, conn net.Conn, s *symmetricSta
|
|||||||
// Server initiates a control server handshake, returning the resulting
|
// Server initiates a control server handshake, returning the resulting
|
||||||
// control connection.
|
// control connection.
|
||||||
//
|
//
|
||||||
|
// maxSupportedVersion is the highest handshake version the server is
|
||||||
|
// willing to handshake with. The server will handshake with any
|
||||||
|
// version from 0 to maxSupportedVersion inclusive, the caller should
|
||||||
|
// inspect conn.Version() to determine what version of the handshake
|
||||||
|
// was executed.
|
||||||
|
//
|
||||||
// optionalInit can be the client's initial handshake message as
|
// optionalInit can be the client's initial handshake message as
|
||||||
// returned by ClientDeferred, or nil in which case the initial
|
// returned by ClientDeferred, or nil in which case the initial
|
||||||
// message is read from conn.
|
// message is read from conn.
|
||||||
//
|
//
|
||||||
// The context deadline, if any, covers the entire handshaking
|
// The context deadline, if any, covers the entire handshaking
|
||||||
// process.
|
// process.
|
||||||
func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, optionalInit []byte) (*Conn, error) {
|
func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, maxSupportedVersion uint16, optionalInit []byte) (*Conn, error) {
|
||||||
if deadline, ok := ctx.Deadline(); ok {
|
if deadline, ok := ctx.Deadline(); ok {
|
||||||
if err := conn.SetDeadline(deadline); err != nil {
|
if err := conn.SetDeadline(deadline); err != nil {
|
||||||
return nil, fmt.Errorf("setting conn deadline: %w", err)
|
return nil, fmt.Errorf("setting conn deadline: %w", err)
|
||||||
@ -239,9 +245,16 @@ func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, o
|
|||||||
} else if _, err := io.ReadFull(conn, init.Header()); err != nil {
|
} else if _, err := io.ReadFull(conn, init.Header()); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if init.Version() != protocolVersion {
|
// Currently, these versions exclusively indicate what the upper
|
||||||
return nil, sendErr("unsupported protocol version")
|
// RPC protocol understands, the Noise handshake is exactly the
|
||||||
|
// same in all versions. If that ever changes, this check will
|
||||||
|
// need to become more complex to handle different kinds of
|
||||||
|
// handshake.
|
||||||
|
if init.Version() > maxSupportedVersion {
|
||||||
|
return nil, sendErr("unsupported handshake version")
|
||||||
}
|
}
|
||||||
|
// Just a rename to make it more obvious what the value is
|
||||||
|
clientVersion := init.Version()
|
||||||
if init.Type() != msgTypeInitiation {
|
if init.Type() != msgTypeInitiation {
|
||||||
return nil, sendErr("unexpected handshake message type")
|
return nil, sendErr("unexpected handshake message type")
|
||||||
}
|
}
|
||||||
@ -257,7 +270,7 @@ func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, o
|
|||||||
|
|
||||||
// prologue. Can only do this once we at least think the client is
|
// prologue. Can only do this once we at least think the client is
|
||||||
// handshaking using a supported version.
|
// handshaking using a supported version.
|
||||||
s.MixHash(protocolVersionPrologue(protocolVersion))
|
s.MixHash(protocolVersionPrologue(clientVersion))
|
||||||
|
|
||||||
// <- s
|
// <- s
|
||||||
// ...
|
// ...
|
||||||
@ -310,7 +323,7 @@ func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, o
|
|||||||
|
|
||||||
c := &Conn{
|
c := &Conn{
|
||||||
conn: conn,
|
conn: conn,
|
||||||
version: protocolVersion,
|
version: clientVersion,
|
||||||
peer: machineKey,
|
peer: machineKey,
|
||||||
handshakeHash: s.h,
|
handshakeHash: s.h,
|
||||||
tx: txState{
|
tx: txState{
|
||||||
|
@ -26,11 +26,11 @@ func TestHandshake(t *testing.T) {
|
|||||||
)
|
)
|
||||||
go func() {
|
go func() {
|
||||||
var err error
|
var err error
|
||||||
server, err = Server(context.Background(), serverConn, serverKey, nil)
|
server, err = Server(context.Background(), serverConn, serverKey, testProtocolVersion, nil)
|
||||||
serverErr <- err
|
serverErr <- err
|
||||||
}()
|
}()
|
||||||
|
|
||||||
client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public())
|
client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public(), testProtocolVersion)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("client connection failed: %v", err)
|
t.Fatalf("client connection failed: %v", err)
|
||||||
}
|
}
|
||||||
@ -42,8 +42,8 @@ func TestHandshake(t *testing.T) {
|
|||||||
t.Fatal("client and server disagree on handshake hash")
|
t.Fatal("client and server disagree on handshake hash")
|
||||||
}
|
}
|
||||||
|
|
||||||
if client.ProtocolVersion() != int(protocolVersion) {
|
if client.ProtocolVersion() != int(testProtocolVersion) {
|
||||||
t.Fatalf("client reporting wrong protocol version %d, want %d", client.ProtocolVersion(), protocolVersion)
|
t.Fatalf("client reporting wrong protocol version %d, want %d", client.ProtocolVersion(), testProtocolVersion)
|
||||||
}
|
}
|
||||||
if client.ProtocolVersion() != server.ProtocolVersion() {
|
if client.ProtocolVersion() != server.ProtocolVersion() {
|
||||||
t.Fatalf("peers disagree on protocol version, client=%d server=%d", client.ProtocolVersion(), server.ProtocolVersion())
|
t.Fatalf("peers disagree on protocol version, client=%d server=%d", client.ProtocolVersion(), server.ProtocolVersion())
|
||||||
@ -78,11 +78,11 @@ func TestNoReuse(t *testing.T) {
|
|||||||
)
|
)
|
||||||
go func() {
|
go func() {
|
||||||
var err error
|
var err error
|
||||||
server, err = Server(context.Background(), serverConn, serverKey, nil)
|
server, err = Server(context.Background(), serverConn, serverKey, testProtocolVersion, nil)
|
||||||
serverErr <- err
|
serverErr <- err
|
||||||
}()
|
}()
|
||||||
|
|
||||||
client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public())
|
client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public(), testProtocolVersion)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("client connection failed: %v", err)
|
t.Fatalf("client connection failed: %v", err)
|
||||||
}
|
}
|
||||||
@ -172,7 +172,7 @@ func TestTampering(t *testing.T) {
|
|||||||
serverErr = make(chan error, 1)
|
serverErr = make(chan error, 1)
|
||||||
)
|
)
|
||||||
go func() {
|
go func() {
|
||||||
_, err := Server(context.Background(), serverConn, serverKey, nil)
|
_, err := Server(context.Background(), serverConn, serverKey, testProtocolVersion, nil)
|
||||||
// If the server failed, we have to close the Conn to
|
// If the server failed, we have to close the Conn to
|
||||||
// unblock the client.
|
// unblock the client.
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -181,7 +181,7 @@ func TestTampering(t *testing.T) {
|
|||||||
serverErr <- err
|
serverErr <- err
|
||||||
}()
|
}()
|
||||||
|
|
||||||
_, err := Client(context.Background(), clientConn, clientKey, serverKey.Public())
|
_, err := Client(context.Background(), clientConn, clientKey, serverKey.Public(), testProtocolVersion)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("client connection succeeded despite tampering")
|
t.Fatal("client connection succeeded despite tampering")
|
||||||
}
|
}
|
||||||
@ -200,11 +200,11 @@ func TestTampering(t *testing.T) {
|
|||||||
serverErr = make(chan error, 1)
|
serverErr = make(chan error, 1)
|
||||||
)
|
)
|
||||||
go func() {
|
go func() {
|
||||||
_, err := Server(context.Background(), serverConn, serverKey, nil)
|
_, err := Server(context.Background(), serverConn, serverKey, testProtocolVersion, nil)
|
||||||
serverErr <- err
|
serverErr <- err
|
||||||
}()
|
}()
|
||||||
|
|
||||||
_, err := Client(context.Background(), clientConn, clientKey, serverKey.Public())
|
_, err := Client(context.Background(), clientConn, clientKey, serverKey.Public(), testProtocolVersion)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("client connection succeeded despite tampering")
|
t.Fatal("client connection succeeded despite tampering")
|
||||||
}
|
}
|
||||||
@ -225,13 +225,13 @@ func TestTampering(t *testing.T) {
|
|||||||
serverErr = make(chan error, 1)
|
serverErr = make(chan error, 1)
|
||||||
)
|
)
|
||||||
go func() {
|
go func() {
|
||||||
server, err := Server(context.Background(), serverConn, serverKey, nil)
|
server, err := Server(context.Background(), serverConn, serverKey, testProtocolVersion, nil)
|
||||||
serverErr <- err
|
serverErr <- err
|
||||||
_, err = io.WriteString(server, strings.Repeat("a", 14))
|
_, err = io.WriteString(server, strings.Repeat("a", 14))
|
||||||
serverErr <- err
|
serverErr <- err
|
||||||
}()
|
}()
|
||||||
|
|
||||||
client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public())
|
client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public(), testProtocolVersion)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("client handshake failed: %v", err)
|
t.Fatalf("client handshake failed: %v", err)
|
||||||
}
|
}
|
||||||
@ -266,7 +266,7 @@ func TestTampering(t *testing.T) {
|
|||||||
serverErr = make(chan error, 1)
|
serverErr = make(chan error, 1)
|
||||||
)
|
)
|
||||||
go func() {
|
go func() {
|
||||||
server, err := Server(context.Background(), serverConn, serverKey, nil)
|
server, err := Server(context.Background(), serverConn, serverKey, testProtocolVersion, nil)
|
||||||
serverErr <- err
|
serverErr <- err
|
||||||
var bs [100]byte
|
var bs [100]byte
|
||||||
// The server needs a timeout if the tampering is hitting the length header.
|
// The server needs a timeout if the tampering is hitting the length header.
|
||||||
@ -281,7 +281,7 @@ func TestTampering(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public())
|
client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public(), testProtocolVersion)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("client handshake failed: %v", err)
|
t.Fatalf("client handshake failed: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -29,7 +29,7 @@ func TestInteropClient(t *testing.T) {
|
|||||||
)
|
)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
server, err := Server(context.Background(), s2, controlKey, nil)
|
server, err := Server(context.Background(), s2, controlKey, testProtocolVersion, nil)
|
||||||
serverErr <- err
|
serverErr <- err
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
@ -77,7 +77,7 @@ func TestInteropServer(t *testing.T) {
|
|||||||
)
|
)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
client, err := Client(context.Background(), s1, machineKey, controlKey.Public())
|
client, err := Client(context.Background(), s1, machineKey, controlKey.Public(), testProtocolVersion)
|
||||||
clientErr <- err
|
clientErr <- err
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
@ -121,11 +121,11 @@ func noiseExplorerClient(conn net.Conn, controlKey key.MachinePublic, machineKey
|
|||||||
copy(mk.public_key[:], machineKey.Public().UntypedBytes())
|
copy(mk.public_key[:], machineKey.Public().UntypedBytes())
|
||||||
var peerKey [32]byte
|
var peerKey [32]byte
|
||||||
copy(peerKey[:], controlKey.UntypedBytes())
|
copy(peerKey[:], controlKey.UntypedBytes())
|
||||||
session := InitSession(true, protocolVersionPrologue(protocolVersion), mk, peerKey)
|
session := InitSession(true, protocolVersionPrologue(testProtocolVersion), mk, peerKey)
|
||||||
|
|
||||||
_, msg1 := SendMessage(&session, nil)
|
_, msg1 := SendMessage(&session, nil)
|
||||||
var hdr [initiationHeaderLen]byte
|
var hdr [initiationHeaderLen]byte
|
||||||
binary.BigEndian.PutUint16(hdr[:2], protocolVersion)
|
binary.BigEndian.PutUint16(hdr[:2], testProtocolVersion)
|
||||||
hdr[2] = msgTypeInitiation
|
hdr[2] = msgTypeInitiation
|
||||||
binary.BigEndian.PutUint16(hdr[3:5], 96)
|
binary.BigEndian.PutUint16(hdr[3:5], 96)
|
||||||
if _, err := conn.Write(hdr[:]); err != nil {
|
if _, err := conn.Write(hdr[:]); err != nil {
|
||||||
@ -193,7 +193,7 @@ func noiseExplorerServer(conn net.Conn, controlKey key.MachinePrivate, wantMachi
|
|||||||
var mk keypair
|
var mk keypair
|
||||||
copy(mk.private_key[:], controlKey.UntypedBytes())
|
copy(mk.private_key[:], controlKey.UntypedBytes())
|
||||||
copy(mk.public_key[:], controlKey.Public().UntypedBytes())
|
copy(mk.public_key[:], controlKey.Public().UntypedBytes())
|
||||||
session := InitSession(false, protocolVersionPrologue(protocolVersion), mk, [32]byte{})
|
session := InitSession(false, protocolVersionPrologue(testProtocolVersion), mk, [32]byte{})
|
||||||
|
|
||||||
var buf [1024]byte
|
var buf [1024]byte
|
||||||
if _, err := io.ReadFull(conn, buf[:101]); err != nil {
|
if _, err := io.ReadFull(conn, buf[:101]); err != nil {
|
||||||
|
@ -39,9 +39,9 @@
|
|||||||
// 16b: message tag (authenticates the whole message)
|
// 16b: message tag (authenticates the whole message)
|
||||||
type initiationMessage [101]byte
|
type initiationMessage [101]byte
|
||||||
|
|
||||||
func mkInitiationMessage() initiationMessage {
|
func mkInitiationMessage(protocolVersion uint16) initiationMessage {
|
||||||
var ret initiationMessage
|
var ret initiationMessage
|
||||||
binary.BigEndian.PutUint16(ret[:2], uint16(protocolVersion))
|
binary.BigEndian.PutUint16(ret[:2], protocolVersion)
|
||||||
ret[2] = msgTypeInitiation
|
ret[2] = msgTypeInitiation
|
||||||
binary.BigEndian.PutUint16(ret[3:5], uint16(len(ret.Payload())))
|
binary.BigEndian.PutUint16(ret[3:5], uint16(len(ret.Payload())))
|
||||||
return ret
|
return ret
|
||||||
|
@ -8,6 +8,7 @@
|
|||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"math"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
@ -17,6 +18,7 @@
|
|||||||
"golang.org/x/net/http2"
|
"golang.org/x/net/http2"
|
||||||
"tailscale.com/control/controlbase"
|
"tailscale.com/control/controlbase"
|
||||||
"tailscale.com/control/controlhttp"
|
"tailscale.com/control/controlhttp"
|
||||||
|
"tailscale.com/tailcfg"
|
||||||
"tailscale.com/types/key"
|
"tailscale.com/types/key"
|
||||||
"tailscale.com/util/multierr"
|
"tailscale.com/util/multierr"
|
||||||
)
|
)
|
||||||
@ -146,7 +148,12 @@ func (nc *noiseClient) dial(_, _ string, _ *tls.Config) (net.Conn, error) {
|
|||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
conn, err := controlhttp.Dial(ctx, nc.serverHost, nc.privKey, nc.serverPubKey)
|
if tailcfg.CurrentCapabilityVersion > math.MaxUint16 {
|
||||||
|
// Panic, because a test should have started failing several
|
||||||
|
// thousand version numbers before getting to this point.
|
||||||
|
panic("capability version is too high to fit in the wire protocol")
|
||||||
|
}
|
||||||
|
conn, err := controlhttp.Dial(ctx, nc.serverHost, nc.privKey, nc.serverPubKey, uint16(tailcfg.CurrentCapabilityVersion))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
28
control/controlclient/noise_test.go
Normal file
28
control/controlclient/noise_test.go
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package controlclient
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"tailscale.com/tailcfg"
|
||||||
|
)
|
||||||
|
|
||||||
|
// maxAllowedNoiseVersion is the highest we expect the Tailscale
|
||||||
|
// capability version to ever get. It's a value close to 2^16, but
|
||||||
|
// with enough leeway that we get a very early warning that it's time
|
||||||
|
// to rework the wire protocol to allow larger versions, while still
|
||||||
|
// giving us headroom to bump this test and fix the build.
|
||||||
|
//
|
||||||
|
// Code elsewhere in the client will panic() if the tailcfg capability
|
||||||
|
// version exceeds 16 bits, so take a failure of this test seriously.
|
||||||
|
const maxAllowedNoiseVersion = math.MaxUint16 - 5000
|
||||||
|
|
||||||
|
func TestNoiseVersion(t *testing.T) {
|
||||||
|
if tailcfg.CurrentCapabilityVersion > maxAllowedNoiseVersion {
|
||||||
|
t.Fatalf("tailcfg.CurrentCapabilityVersion is %d, want <=%d", tailcfg.CurrentCapabilityVersion, maxAllowedNoiseVersion)
|
||||||
|
}
|
||||||
|
}
|
@ -65,7 +65,7 @@
|
|||||||
//
|
//
|
||||||
// The provided ctx is only used for the initial connection, until
|
// The provided ctx is only used for the initial connection, until
|
||||||
// Dial returns. It does not affect the connection once established.
|
// Dial returns. It does not affect the connection once established.
|
||||||
func Dial(ctx context.Context, addr string, machineKey key.MachinePrivate, controlKey key.MachinePublic) (*controlbase.Conn, error) {
|
func Dial(ctx context.Context, addr string, machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (*controlbase.Conn, error) {
|
||||||
host, port, err := net.SplitHostPort(addr)
|
host, port, err := net.SplitHostPort(addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -77,6 +77,7 @@ func Dial(ctx context.Context, addr string, machineKey key.MachinePrivate, contr
|
|||||||
httpsPort: "443",
|
httpsPort: "443",
|
||||||
machineKey: machineKey,
|
machineKey: machineKey,
|
||||||
controlKey: controlKey,
|
controlKey: controlKey,
|
||||||
|
version: protocolVersion,
|
||||||
proxyFunc: tshttpproxy.ProxyFromEnvironment,
|
proxyFunc: tshttpproxy.ProxyFromEnvironment,
|
||||||
}
|
}
|
||||||
return a.dial()
|
return a.dial()
|
||||||
@ -89,6 +90,7 @@ type dialParams struct {
|
|||||||
httpsPort string
|
httpsPort string
|
||||||
machineKey key.MachinePrivate
|
machineKey key.MachinePrivate
|
||||||
controlKey key.MachinePublic
|
controlKey key.MachinePublic
|
||||||
|
version uint16
|
||||||
proxyFunc func(*http.Request) (*url.URL, error) // or nil
|
proxyFunc func(*http.Request) (*url.URL, error) // or nil
|
||||||
|
|
||||||
// For tests only
|
// For tests only
|
||||||
@ -96,7 +98,7 @@ type dialParams struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *dialParams) dial() (*controlbase.Conn, error) {
|
func (a *dialParams) dial() (*controlbase.Conn, error) {
|
||||||
init, cont, err := controlbase.ClientDeferred(a.machineKey, a.controlKey)
|
init, cont, err := controlbase.ClientDeferred(a.machineKey, a.controlKey, a.version)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -120,7 +122,7 @@ func (a *dialParams) dial() (*controlbase.Conn, error) {
|
|||||||
// being difficult and see if we can get through over HTTPS.
|
// being difficult and see if we can get through over HTTPS.
|
||||||
u.Scheme = "https"
|
u.Scheme = "https"
|
||||||
u.Host = net.JoinHostPort(a.host, a.httpsPort)
|
u.Host = net.JoinHostPort(a.host, a.httpsPort)
|
||||||
init, cont, err = controlbase.ClientDeferred(a.machineKey, a.controlKey)
|
init, cont, err = controlbase.ClientDeferred(a.machineKey, a.controlKey, a.version)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -104,9 +104,10 @@ func TestControlHTTP(t *testing.T) {
|
|||||||
func testControlHTTP(t *testing.T, proxy proxy) {
|
func testControlHTTP(t *testing.T, proxy proxy) {
|
||||||
client, server := key.NewMachine(), key.NewMachine()
|
client, server := key.NewMachine(), key.NewMachine()
|
||||||
|
|
||||||
|
const testProtocolVersion = 1
|
||||||
sch := make(chan serverResult, 1)
|
sch := make(chan serverResult, 1)
|
||||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
conn, err := AcceptHTTP(context.Background(), w, r, server)
|
conn, err := AcceptHTTP(context.Background(), w, r, server, testProtocolVersion)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
}
|
}
|
||||||
@ -152,6 +153,7 @@ func testControlHTTP(t *testing.T, proxy proxy) {
|
|||||||
httpsPort: strconv.Itoa(httpsLn.Addr().(*net.TCPAddr).Port),
|
httpsPort: strconv.Itoa(httpsLn.Addr().(*net.TCPAddr).Port),
|
||||||
machineKey: client,
|
machineKey: client,
|
||||||
controlKey: server.Public(),
|
controlKey: server.Public(),
|
||||||
|
version: testProtocolVersion,
|
||||||
insecureTLS: true,
|
insecureTLS: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -21,7 +21,7 @@
|
|||||||
//
|
//
|
||||||
// AcceptHTTP always writes an HTTP response to w. The caller must not
|
// AcceptHTTP always writes an HTTP response to w. The caller must not
|
||||||
// attempt their own response after calling AcceptHTTP.
|
// attempt their own response after calling AcceptHTTP.
|
||||||
func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, private key.MachinePrivate) (*controlbase.Conn, error) {
|
func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, private key.MachinePrivate, maxSupportedVersion uint16) (*controlbase.Conn, error) {
|
||||||
next := r.Header.Get("Upgrade")
|
next := r.Header.Get("Upgrade")
|
||||||
if next == "" {
|
if next == "" {
|
||||||
http.Error(w, "missing next protocol", http.StatusBadRequest)
|
http.Error(w, "missing next protocol", http.StatusBadRequest)
|
||||||
@ -63,7 +63,7 @@ func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, pri
|
|||||||
}
|
}
|
||||||
conn = netutil.NewDrainBufConn(conn, brw.Reader)
|
conn = netutil.NewDrainBufConn(conn, brw.Reader)
|
||||||
|
|
||||||
nc, err := controlbase.Server(ctx, conn, private, init)
|
nc, err := controlbase.Server(ctx, conn, private, maxSupportedVersion, init)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
return nil, fmt.Errorf("noise handshake failed: %w", err)
|
return nil, fmt.Errorf("noise handshake failed: %w", err)
|
||||||
|
Loading…
Reference in New Issue
Block a user