mirror of
https://github.com/tailscale/tailscale.git
synced 2025-04-06 16:25:50 +00:00
413 lines
12 KiB
Go
413 lines
12 KiB
Go
![]() |
// Copyright 2024 The Go Authors. All rights reserved.
|
||
|
// Use of this source code is governed by a BSD-style
|
||
|
// license that can be found in the LICENSE file.
|
||
|
|
||
|
package ssh
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"strings"
|
||
|
"testing"
|
||
|
)
|
||
|
|
||
|
func doClientServerAuth(t *testing.T, serverConfig *ServerConfig, clientConfig *ClientConfig) ([]error, error) {
|
||
|
c1, c2, err := netPipe()
|
||
|
if err != nil {
|
||
|
t.Fatalf("netPipe: %v", err)
|
||
|
}
|
||
|
defer c1.Close()
|
||
|
defer c2.Close()
|
||
|
|
||
|
var serverAuthErrors []error
|
||
|
|
||
|
serverConfig.AddHostKey(testSigners["rsa"])
|
||
|
serverConfig.AuthLogCallback = func(conn ConnMetadata, method string, err error) {
|
||
|
serverAuthErrors = append(serverAuthErrors, err)
|
||
|
}
|
||
|
go newServer(c1, serverConfig)
|
||
|
c, _, _, err := NewClientConn(c2, "", clientConfig)
|
||
|
if err == nil {
|
||
|
c.Close()
|
||
|
}
|
||
|
return serverAuthErrors, err
|
||
|
}
|
||
|
|
||
|
func TestMultiStepAuth(t *testing.T) {
|
||
|
// This user can login with password, public key or public key + password.
|
||
|
username := "testuser"
|
||
|
// This user can login with public key + password only.
|
||
|
usernameSecondFactor := "testuser_second_factor"
|
||
|
errPwdAuthFailed := errors.New("password auth failed")
|
||
|
errWrongSequence := errors.New("wrong sequence")
|
||
|
|
||
|
serverConfig := &ServerConfig{
|
||
|
PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
|
||
|
if conn.User() == usernameSecondFactor {
|
||
|
return nil, errWrongSequence
|
||
|
}
|
||
|
if conn.User() == username && string(password) == clientPassword {
|
||
|
return nil, nil
|
||
|
}
|
||
|
return nil, errPwdAuthFailed
|
||
|
},
|
||
|
PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
|
||
|
if bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) {
|
||
|
if conn.User() == usernameSecondFactor {
|
||
|
return nil, &PartialSuccessError{
|
||
|
Next: ServerAuthCallbacks{
|
||
|
PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
|
||
|
if string(password) == clientPassword {
|
||
|
return nil, nil
|
||
|
}
|
||
|
return nil, errPwdAuthFailed
|
||
|
},
|
||
|
},
|
||
|
}
|
||
|
}
|
||
|
return nil, nil
|
||
|
}
|
||
|
return nil, fmt.Errorf("pubkey for %q not acceptable", conn.User())
|
||
|
},
|
||
|
}
|
||
|
|
||
|
clientConfig := &ClientConfig{
|
||
|
User: usernameSecondFactor,
|
||
|
Auth: []AuthMethod{
|
||
|
PublicKeys(testSigners["rsa"]),
|
||
|
Password(clientPassword),
|
||
|
},
|
||
|
HostKeyCallback: InsecureIgnoreHostKey(),
|
||
|
}
|
||
|
|
||
|
serverAuthErrors, err := doClientServerAuth(t, serverConfig, clientConfig)
|
||
|
if err != nil {
|
||
|
t.Fatalf("client login error: %s", err)
|
||
|
}
|
||
|
|
||
|
// The error sequence is:
|
||
|
// - no auth passed yet
|
||
|
// - partial success
|
||
|
// - nil
|
||
|
if len(serverAuthErrors) != 3 {
|
||
|
t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
|
||
|
}
|
||
|
if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok {
|
||
|
t.Fatalf("expected partial success error, got: %v", serverAuthErrors[1])
|
||
|
}
|
||
|
// Now test a wrong sequence.
|
||
|
clientConfig.Auth = []AuthMethod{
|
||
|
Password(clientPassword),
|
||
|
PublicKeys(testSigners["rsa"]),
|
||
|
}
|
||
|
|
||
|
serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
|
||
|
if err == nil {
|
||
|
t.Fatal("client login with wrong sequence must fail")
|
||
|
}
|
||
|
// The error sequence is:
|
||
|
// - no auth passed yet
|
||
|
// - wrong sequence
|
||
|
// - partial success
|
||
|
if len(serverAuthErrors) != 3 {
|
||
|
t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
|
||
|
}
|
||
|
if serverAuthErrors[1] != errWrongSequence {
|
||
|
t.Fatal("server not returned wrong sequence")
|
||
|
}
|
||
|
if _, ok := serverAuthErrors[2].(*PartialSuccessError); !ok {
|
||
|
t.Fatalf("expected partial success error, got: %v", serverAuthErrors[2])
|
||
|
}
|
||
|
// Now test using a correct sequence but a wrong password before the right
|
||
|
// one.
|
||
|
n := 0
|
||
|
passwords := []string{"WRONG", "WRONG", clientPassword}
|
||
|
clientConfig.Auth = []AuthMethod{
|
||
|
PublicKeys(testSigners["rsa"]),
|
||
|
RetryableAuthMethod(PasswordCallback(func() (string, error) {
|
||
|
p := passwords[n]
|
||
|
n++
|
||
|
return p, nil
|
||
|
}), 3),
|
||
|
}
|
||
|
|
||
|
serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
|
||
|
if err != nil {
|
||
|
t.Fatalf("client login error: %s", err)
|
||
|
}
|
||
|
// The error sequence is:
|
||
|
// - no auth passed yet
|
||
|
// - partial success
|
||
|
// - wrong password
|
||
|
// - wrong password
|
||
|
// - nil
|
||
|
if len(serverAuthErrors) != 5 {
|
||
|
t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
|
||
|
}
|
||
|
if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok {
|
||
|
t.Fatal("server not returned partial success")
|
||
|
}
|
||
|
if serverAuthErrors[2] != errPwdAuthFailed {
|
||
|
t.Fatal("server not returned password authentication failed")
|
||
|
}
|
||
|
if serverAuthErrors[3] != errPwdAuthFailed {
|
||
|
t.Fatal("server not returned password authentication failed")
|
||
|
}
|
||
|
// Only password authentication should fail.
|
||
|
clientConfig.Auth = []AuthMethod{
|
||
|
Password(clientPassword),
|
||
|
}
|
||
|
|
||
|
serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
|
||
|
if err == nil {
|
||
|
t.Fatal("client login with password only must fail")
|
||
|
}
|
||
|
// The error sequence is:
|
||
|
// - no auth passed yet
|
||
|
// - wrong sequence
|
||
|
if len(serverAuthErrors) != 2 {
|
||
|
t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
|
||
|
}
|
||
|
if serverAuthErrors[1] != errWrongSequence {
|
||
|
t.Fatal("server not returned wrong sequence")
|
||
|
}
|
||
|
|
||
|
// Only public key authentication should fail.
|
||
|
clientConfig.Auth = []AuthMethod{
|
||
|
PublicKeys(testSigners["rsa"]),
|
||
|
}
|
||
|
|
||
|
serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
|
||
|
if err == nil {
|
||
|
t.Fatal("client login with public key only must fail")
|
||
|
}
|
||
|
// The error sequence is:
|
||
|
// - no auth passed yet
|
||
|
// - partial success
|
||
|
if len(serverAuthErrors) != 2 {
|
||
|
t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
|
||
|
}
|
||
|
if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok {
|
||
|
t.Fatal("server not returned partial success")
|
||
|
}
|
||
|
|
||
|
// Public key and wrong password.
|
||
|
clientConfig.Auth = []AuthMethod{
|
||
|
PublicKeys(testSigners["rsa"]),
|
||
|
Password("WRONG"),
|
||
|
}
|
||
|
|
||
|
serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
|
||
|
if err == nil {
|
||
|
t.Fatal("client login with wrong password after public key must fail")
|
||
|
}
|
||
|
// The error sequence is:
|
||
|
// - no auth passed yet
|
||
|
// - partial success
|
||
|
// - password auth failed
|
||
|
if len(serverAuthErrors) != 3 {
|
||
|
t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
|
||
|
}
|
||
|
if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok {
|
||
|
t.Fatal("server not returned partial success")
|
||
|
}
|
||
|
if serverAuthErrors[2] != errPwdAuthFailed {
|
||
|
t.Fatal("server not returned password authentication failed")
|
||
|
}
|
||
|
|
||
|
// Public key, public key again and then correct password. Public key
|
||
|
// authentication is attempted only once because the partial success error
|
||
|
// returns only "password" as the allowed authentication method.
|
||
|
clientConfig.Auth = []AuthMethod{
|
||
|
PublicKeys(testSigners["rsa"]),
|
||
|
PublicKeys(testSigners["rsa"]),
|
||
|
Password(clientPassword),
|
||
|
}
|
||
|
|
||
|
serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
|
||
|
if err != nil {
|
||
|
t.Fatalf("client login error: %s", err)
|
||
|
}
|
||
|
// The error sequence is:
|
||
|
// - no auth passed yet
|
||
|
// - partial success
|
||
|
// - nil
|
||
|
if len(serverAuthErrors) != 3 {
|
||
|
t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
|
||
|
}
|
||
|
if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok {
|
||
|
t.Fatal("server not returned partial success")
|
||
|
}
|
||
|
|
||
|
// The unrestricted username can do anything
|
||
|
clientConfig = &ClientConfig{
|
||
|
User: username,
|
||
|
Auth: []AuthMethod{
|
||
|
PublicKeys(testSigners["rsa"]),
|
||
|
Password(clientPassword),
|
||
|
},
|
||
|
HostKeyCallback: InsecureIgnoreHostKey(),
|
||
|
}
|
||
|
|
||
|
_, err = doClientServerAuth(t, serverConfig, clientConfig)
|
||
|
if err != nil {
|
||
|
t.Fatalf("unrestricted client login error: %s", err)
|
||
|
}
|
||
|
|
||
|
clientConfig = &ClientConfig{
|
||
|
User: username,
|
||
|
Auth: []AuthMethod{
|
||
|
PublicKeys(testSigners["rsa"]),
|
||
|
},
|
||
|
HostKeyCallback: InsecureIgnoreHostKey(),
|
||
|
}
|
||
|
|
||
|
_, err = doClientServerAuth(t, serverConfig, clientConfig)
|
||
|
if err != nil {
|
||
|
t.Fatalf("unrestricted client login error: %s", err)
|
||
|
}
|
||
|
|
||
|
clientConfig = &ClientConfig{
|
||
|
User: username,
|
||
|
Auth: []AuthMethod{
|
||
|
Password(clientPassword),
|
||
|
},
|
||
|
HostKeyCallback: InsecureIgnoreHostKey(),
|
||
|
}
|
||
|
|
||
|
_, err = doClientServerAuth(t, serverConfig, clientConfig)
|
||
|
if err != nil {
|
||
|
t.Fatalf("unrestricted client login error: %s", err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestDynamicAuthCallbacks(t *testing.T) {
|
||
|
user1 := "user1"
|
||
|
user2 := "user2"
|
||
|
errInvalidCredentials := errors.New("invalid credentials")
|
||
|
|
||
|
serverConfig := &ServerConfig{
|
||
|
NoClientAuth: true,
|
||
|
NoClientAuthCallback: func(conn ConnMetadata) (*Permissions, error) {
|
||
|
switch conn.User() {
|
||
|
case user1:
|
||
|
return nil, &PartialSuccessError{
|
||
|
Next: ServerAuthCallbacks{
|
||
|
PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
|
||
|
if conn.User() == user1 && string(password) == clientPassword {
|
||
|
return nil, nil
|
||
|
}
|
||
|
return nil, errInvalidCredentials
|
||
|
},
|
||
|
},
|
||
|
}
|
||
|
case user2:
|
||
|
return nil, &PartialSuccessError{
|
||
|
Next: ServerAuthCallbacks{
|
||
|
PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
|
||
|
if bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) {
|
||
|
if conn.User() == user2 {
|
||
|
return nil, nil
|
||
|
}
|
||
|
}
|
||
|
return nil, errInvalidCredentials
|
||
|
},
|
||
|
},
|
||
|
}
|
||
|
default:
|
||
|
return nil, errInvalidCredentials
|
||
|
}
|
||
|
},
|
||
|
}
|
||
|
|
||
|
clientConfig := &ClientConfig{
|
||
|
User: user1,
|
||
|
Auth: []AuthMethod{
|
||
|
Password(clientPassword),
|
||
|
},
|
||
|
HostKeyCallback: InsecureIgnoreHostKey(),
|
||
|
}
|
||
|
|
||
|
serverAuthErrors, err := doClientServerAuth(t, serverConfig, clientConfig)
|
||
|
if err != nil {
|
||
|
t.Fatalf("client login error: %s", err)
|
||
|
}
|
||
|
// The error sequence is:
|
||
|
// - partial success
|
||
|
// - nil
|
||
|
if len(serverAuthErrors) != 2 {
|
||
|
t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
|
||
|
}
|
||
|
if _, ok := serverAuthErrors[0].(*PartialSuccessError); !ok {
|
||
|
t.Fatal("server not returned partial success")
|
||
|
}
|
||
|
|
||
|
clientConfig = &ClientConfig{
|
||
|
User: user2,
|
||
|
Auth: []AuthMethod{
|
||
|
PublicKeys(testSigners["rsa"]),
|
||
|
},
|
||
|
HostKeyCallback: InsecureIgnoreHostKey(),
|
||
|
}
|
||
|
|
||
|
serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
|
||
|
if err != nil {
|
||
|
t.Fatalf("client login error: %s", err)
|
||
|
}
|
||
|
// The error sequence is:
|
||
|
// - partial success
|
||
|
// - nil
|
||
|
if len(serverAuthErrors) != 2 {
|
||
|
t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
|
||
|
}
|
||
|
if _, ok := serverAuthErrors[0].(*PartialSuccessError); !ok {
|
||
|
t.Fatal("server not returned partial success")
|
||
|
}
|
||
|
|
||
|
// user1 cannot login with public key
|
||
|
clientConfig = &ClientConfig{
|
||
|
User: user1,
|
||
|
Auth: []AuthMethod{
|
||
|
PublicKeys(testSigners["rsa"]),
|
||
|
},
|
||
|
HostKeyCallback: InsecureIgnoreHostKey(),
|
||
|
}
|
||
|
|
||
|
serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
|
||
|
if err == nil {
|
||
|
t.Fatal("user1 login with public key must fail")
|
||
|
}
|
||
|
if !strings.Contains(err.Error(), "no supported methods remain") {
|
||
|
t.Errorf("got %v, expected 'no supported methods remain'", err)
|
||
|
}
|
||
|
if len(serverAuthErrors) != 1 {
|
||
|
t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
|
||
|
}
|
||
|
if _, ok := serverAuthErrors[0].(*PartialSuccessError); !ok {
|
||
|
t.Fatal("server not returned partial success")
|
||
|
}
|
||
|
// user2 cannot login with password
|
||
|
clientConfig = &ClientConfig{
|
||
|
User: user2,
|
||
|
Auth: []AuthMethod{
|
||
|
Password(clientPassword),
|
||
|
},
|
||
|
HostKeyCallback: InsecureIgnoreHostKey(),
|
||
|
}
|
||
|
|
||
|
serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
|
||
|
if err == nil {
|
||
|
t.Fatal("user2 login with password must fail")
|
||
|
}
|
||
|
if !strings.Contains(err.Error(), "no supported methods remain") {
|
||
|
t.Errorf("got %v, expected 'no supported methods remain'", err)
|
||
|
}
|
||
|
if len(serverAuthErrors) != 1 {
|
||
|
t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
|
||
|
}
|
||
|
if _, ok := serverAuthErrors[0].(*PartialSuccessError); !ok {
|
||
|
t.Fatal("server not returned partial success")
|
||
|
}
|
||
|
}
|