// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package ssh

import (
	"bytes"
	"errors"
	"fmt"
	"strings"
	"testing"
)

// doClientServerAuth runs one client/server authentication exchange over a
// connection pair obtained from netPipe.
//
// The supplied serverConfig is modified: the testSigners["rsa"] host key is
// added and AuthLogCallback is replaced with a recorder, so any callback the
// caller installed there is overwritten.
//
// It returns every error handed to the auth log callback, in the order
// received (nil entries included), together with the error returned by
// NewClientConn. The test is aborted if the connection pair cannot be created.
func doClientServerAuth(t *testing.T, serverConfig *ServerConfig, clientConfig *ClientConfig) ([]error, error) {
	// Attribute a netPipe failure to the calling scenario rather than to this
	// helper.
	t.Helper()

	c1, c2, err := netPipe()
	if err != nil {
		t.Fatalf("netPipe: %v", err)
	}
	defer c1.Close()
	defer c2.Close()

	var serverAuthErrors []error

	serverConfig.AddHostKey(testSigners["rsa"])
	// Record the outcome of every authentication attempt so callers can
	// assert on the exact sequence.
	serverConfig.AuthLogCallback = func(conn ConnMetadata, method string, err error) {
		serverAuthErrors = append(serverAuthErrors, err)
	}
	go newServer(c1, serverConfig)
	c, _, _, err := NewClientConn(c2, "", clientConfig)
	if err == nil {
		// Only a successfully established connection needs closing here; the
		// raw pipe ends are closed by the deferred calls above.
		c.Close()
	}
	return serverAuthErrors, err
}

func TestMultiStepAuth(t *testing.T) {
	// This user can login with password, public key or public key + password.
	username := "testuser"
	// This user can login with public key + password only.
	usernameSecondFactor := "testuser_second_factor"
	errPwdAuthFailed := errors.New("password auth failed")
	errWrongSequence := errors.New("wrong sequence")

	serverConfig := &ServerConfig{
		PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
			if conn.User() == usernameSecondFactor {
				return nil, errWrongSequence
			}
			if conn.User() == username && string(password) == clientPassword {
				return nil, nil
			}
			return nil, errPwdAuthFailed
		},
		PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
			if bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) {
				if conn.User() == usernameSecondFactor {
					return nil, &PartialSuccessError{
						Next: ServerAuthCallbacks{
							PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
								if string(password) == clientPassword {
									return nil, nil
								}
								return nil, errPwdAuthFailed
							},
						},
					}
				}
				return nil, nil
			}
			return nil, fmt.Errorf("pubkey for %q not acceptable", conn.User())
		},
	}

	clientConfig := &ClientConfig{
		User: usernameSecondFactor,
		Auth: []AuthMethod{
			PublicKeys(testSigners["rsa"]),
			Password(clientPassword),
		},
		HostKeyCallback: InsecureIgnoreHostKey(),
	}

	serverAuthErrors, err := doClientServerAuth(t, serverConfig, clientConfig)
	if err != nil {
		t.Fatalf("client login error: %s", err)
	}

	// The error sequence is:
	// - no auth passed yet
	// - partial success
	// - nil
	if len(serverAuthErrors) != 3 {
		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
	}
	if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok {
		t.Fatalf("expected partial success error, got: %v", serverAuthErrors[1])
	}
	// Now test a wrong sequence.
	clientConfig.Auth = []AuthMethod{
		Password(clientPassword),
		PublicKeys(testSigners["rsa"]),
	}

	serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
	if err == nil {
		t.Fatal("client login with wrong sequence must fail")
	}
	// The error sequence is:
	// - no auth passed yet
	// - wrong sequence
	// - partial success
	if len(serverAuthErrors) != 3 {
		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
	}
	if serverAuthErrors[1] != errWrongSequence {
		t.Fatal("server not returned wrong sequence")
	}
	if _, ok := serverAuthErrors[2].(*PartialSuccessError); !ok {
		t.Fatalf("expected partial success error, got: %v", serverAuthErrors[2])
	}
	// Now test using a correct sequence but a wrong password before the right
	// one.
	n := 0
	passwords := []string{"WRONG", "WRONG", clientPassword}
	clientConfig.Auth = []AuthMethod{
		PublicKeys(testSigners["rsa"]),
		RetryableAuthMethod(PasswordCallback(func() (string, error) {
			p := passwords[n]
			n++
			return p, nil
		}), 3),
	}

	serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
	if err != nil {
		t.Fatalf("client login error: %s", err)
	}
	// The error sequence is:
	// - no auth passed yet
	// - partial success
	// - wrong password
	// - wrong password
	// - nil
	if len(serverAuthErrors) != 5 {
		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
	}
	if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok {
		t.Fatal("server not returned partial success")
	}
	if serverAuthErrors[2] != errPwdAuthFailed {
		t.Fatal("server not returned password authentication failed")
	}
	if serverAuthErrors[3] != errPwdAuthFailed {
		t.Fatal("server not returned password authentication failed")
	}
	// Only password authentication should fail.
	clientConfig.Auth = []AuthMethod{
		Password(clientPassword),
	}

	serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
	if err == nil {
		t.Fatal("client login with password only must fail")
	}
	// The error sequence is:
	// - no auth passed yet
	// - wrong sequence
	if len(serverAuthErrors) != 2 {
		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
	}
	if serverAuthErrors[1] != errWrongSequence {
		t.Fatal("server not returned wrong sequence")
	}

	// Only public key authentication should fail.
	clientConfig.Auth = []AuthMethod{
		PublicKeys(testSigners["rsa"]),
	}

	serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
	if err == nil {
		t.Fatal("client login with public key only must fail")
	}
	// The error sequence is:
	// - no auth passed yet
	// - partial success
	if len(serverAuthErrors) != 2 {
		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
	}
	if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok {
		t.Fatal("server not returned partial success")
	}

	// Public key and wrong password.
	clientConfig.Auth = []AuthMethod{
		PublicKeys(testSigners["rsa"]),
		Password("WRONG"),
	}

	serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
	if err == nil {
		t.Fatal("client login with wrong password after public key must fail")
	}
	// The error sequence is:
	// - no auth passed yet
	// - partial success
	// - password auth failed
	if len(serverAuthErrors) != 3 {
		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
	}
	if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok {
		t.Fatal("server not returned partial success")
	}
	if serverAuthErrors[2] != errPwdAuthFailed {
		t.Fatal("server not returned password authentication failed")
	}

	// Public key, public key again and then correct password. Public key
	// authentication is attempted only once because the partial success error
	// returns only "password" as the allowed authentication method.
	clientConfig.Auth = []AuthMethod{
		PublicKeys(testSigners["rsa"]),
		PublicKeys(testSigners["rsa"]),
		Password(clientPassword),
	}

	serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
	if err != nil {
		t.Fatalf("client login error: %s", err)
	}
	// The error sequence is:
	// - no auth passed yet
	// - partial success
	// - nil
	if len(serverAuthErrors) != 3 {
		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
	}
	if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok {
		t.Fatal("server not returned partial success")
	}

	// The unrestricted username can do anything
	clientConfig = &ClientConfig{
		User: username,
		Auth: []AuthMethod{
			PublicKeys(testSigners["rsa"]),
			Password(clientPassword),
		},
		HostKeyCallback: InsecureIgnoreHostKey(),
	}

	_, err = doClientServerAuth(t, serverConfig, clientConfig)
	if err != nil {
		t.Fatalf("unrestricted client login error: %s", err)
	}

	clientConfig = &ClientConfig{
		User: username,
		Auth: []AuthMethod{
			PublicKeys(testSigners["rsa"]),
		},
		HostKeyCallback: InsecureIgnoreHostKey(),
	}

	_, err = doClientServerAuth(t, serverConfig, clientConfig)
	if err != nil {
		t.Fatalf("unrestricted client login error: %s", err)
	}

	clientConfig = &ClientConfig{
		User: username,
		Auth: []AuthMethod{
			Password(clientPassword),
		},
		HostKeyCallback: InsecureIgnoreHostKey(),
	}

	_, err = doClientServerAuth(t, serverConfig, clientConfig)
	if err != nil {
		t.Fatalf("unrestricted client login error: %s", err)
	}
}

func TestDynamicAuthCallbacks(t *testing.T) {
	user1 := "user1"
	user2 := "user2"
	errInvalidCredentials := errors.New("invalid credentials")

	serverConfig := &ServerConfig{
		NoClientAuth: true,
		NoClientAuthCallback: func(conn ConnMetadata) (*Permissions, error) {
			switch conn.User() {
			case user1:
				return nil, &PartialSuccessError{
					Next: ServerAuthCallbacks{
						PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
							if conn.User() == user1 && string(password) == clientPassword {
								return nil, nil
							}
							return nil, errInvalidCredentials
						},
					},
				}
			case user2:
				return nil, &PartialSuccessError{
					Next: ServerAuthCallbacks{
						PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
							if bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) {
								if conn.User() == user2 {
									return nil, nil
								}
							}
							return nil, errInvalidCredentials
						},
					},
				}
			default:
				return nil, errInvalidCredentials
			}
		},
	}

	clientConfig := &ClientConfig{
		User: user1,
		Auth: []AuthMethod{
			Password(clientPassword),
		},
		HostKeyCallback: InsecureIgnoreHostKey(),
	}

	serverAuthErrors, err := doClientServerAuth(t, serverConfig, clientConfig)
	if err != nil {
		t.Fatalf("client login error: %s", err)
	}
	// The error sequence is:
	// - partial success
	// - nil
	if len(serverAuthErrors) != 2 {
		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
	}
	if _, ok := serverAuthErrors[0].(*PartialSuccessError); !ok {
		t.Fatal("server not returned partial success")
	}

	clientConfig = &ClientConfig{
		User: user2,
		Auth: []AuthMethod{
			PublicKeys(testSigners["rsa"]),
		},
		HostKeyCallback: InsecureIgnoreHostKey(),
	}

	serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
	if err != nil {
		t.Fatalf("client login error: %s", err)
	}
	// The error sequence is:
	// - partial success
	// - nil
	if len(serverAuthErrors) != 2 {
		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
	}
	if _, ok := serverAuthErrors[0].(*PartialSuccessError); !ok {
		t.Fatal("server not returned partial success")
	}

	// user1 cannot login with public key
	clientConfig = &ClientConfig{
		User: user1,
		Auth: []AuthMethod{
			PublicKeys(testSigners["rsa"]),
		},
		HostKeyCallback: InsecureIgnoreHostKey(),
	}

	serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
	if err == nil {
		t.Fatal("user1 login with public key must fail")
	}
	if !strings.Contains(err.Error(), "no supported methods remain") {
		t.Errorf("got %v, expected 'no supported methods remain'", err)
	}
	if len(serverAuthErrors) != 1 {
		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
	}
	if _, ok := serverAuthErrors[0].(*PartialSuccessError); !ok {
		t.Fatal("server not returned partial success")
	}
	// user2 cannot login with password
	clientConfig = &ClientConfig{
		User: user2,
		Auth: []AuthMethod{
			Password(clientPassword),
		},
		HostKeyCallback: InsecureIgnoreHostKey(),
	}

	serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
	if err == nil {
		t.Fatal("user2 login with password must fail")
	}
	if !strings.Contains(err.Error(), "no supported methods remain") {
		t.Errorf("got %v, expected 'no supported methods remain'", err)
	}
	if len(serverAuthErrors) != 1 {
		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
	}
	if _, ok := serverAuthErrors[0].(*PartialSuccessError); !ok {
		t.Fatal("server not returned partial success")
	}
}