// Copyright 2024 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package ssh import ( "bytes" "errors" "fmt" "strings" "testing" ) func doClientServerAuth(t *testing.T, serverConfig *ServerConfig, clientConfig *ClientConfig) ([]error, error) { c1, c2, err := netPipe() if err != nil { t.Fatalf("netPipe: %v", err) } defer c1.Close() defer c2.Close() var serverAuthErrors []error serverConfig.AddHostKey(testSigners["rsa"]) serverConfig.AuthLogCallback = func(conn ConnMetadata, method string, err error) { serverAuthErrors = append(serverAuthErrors, err) } go newServer(c1, serverConfig) c, _, _, err := NewClientConn(c2, "", clientConfig) if err == nil { c.Close() } return serverAuthErrors, err } func TestMultiStepAuth(t *testing.T) { // This user can login with password, public key or public key + password. username := "testuser" // This user can login with public key + password only. usernameSecondFactor := "testuser_second_factor" errPwdAuthFailed := errors.New("password auth failed") errWrongSequence := errors.New("wrong sequence") serverConfig := &ServerConfig{ PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) { if conn.User() == usernameSecondFactor { return nil, errWrongSequence } if conn.User() == username && string(password) == clientPassword { return nil, nil } return nil, errPwdAuthFailed }, PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) { if bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) { if conn.User() == usernameSecondFactor { return nil, &PartialSuccessError{ Next: ServerAuthCallbacks{ PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) { if string(password) == clientPassword { return nil, nil } return nil, errPwdAuthFailed }, }, } } return nil, nil } return nil, fmt.Errorf("pubkey for %q not acceptable", conn.User()) }, } clientConfig := &ClientConfig{ User: usernameSecondFactor, Auth: []AuthMethod{ PublicKeys(testSigners["rsa"]), Password(clientPassword), }, HostKeyCallback: InsecureIgnoreHostKey(), } serverAuthErrors, err := doClientServerAuth(t, serverConfig, clientConfig) if err != nil { t.Fatalf("client login error: %s", err) } // The error sequence is: // - no auth passed yet // - partial success // - nil if len(serverAuthErrors) != 3 { t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) } if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok { t.Fatalf("expected partial success error, got: %v", serverAuthErrors[1]) } // Now test a wrong sequence. clientConfig.Auth = []AuthMethod{ Password(clientPassword), PublicKeys(testSigners["rsa"]), } serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig) if err == nil { t.Fatal("client login with wrong sequence must fail") } // The error sequence is: // - no auth passed yet // - wrong sequence // - partial success if len(serverAuthErrors) != 3 { t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) } if serverAuthErrors[1] != errWrongSequence { t.Fatal("server not returned wrong sequence") } if _, ok := serverAuthErrors[2].(*PartialSuccessError); !ok { t.Fatalf("expected partial success error, got: %v", serverAuthErrors[2]) } // Now test using a correct sequence but a wrong password before the right // one. n := 0 passwords := []string{"WRONG", "WRONG", clientPassword} clientConfig.Auth = []AuthMethod{ PublicKeys(testSigners["rsa"]), RetryableAuthMethod(PasswordCallback(func() (string, error) { p := passwords[n] n++ return p, nil }), 3), } serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig) if err != nil { t.Fatalf("client login error: %s", err) } // The error sequence is: // - no auth passed yet // - partial success // - wrong password // - wrong password // - nil if len(serverAuthErrors) != 5 { t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) } if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok { t.Fatal("server not returned partial success") } if serverAuthErrors[2] != errPwdAuthFailed { t.Fatal("server not returned password authentication failed") } if serverAuthErrors[3] != errPwdAuthFailed { t.Fatal("server not returned password authentication failed") } // Only password authentication should fail. clientConfig.Auth = []AuthMethod{ Password(clientPassword), } serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig) if err == nil { t.Fatal("client login with password only must fail") } // The error sequence is: // - no auth passed yet // - wrong sequence if len(serverAuthErrors) != 2 { t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) } if serverAuthErrors[1] != errWrongSequence { t.Fatal("server not returned wrong sequence") } // Only public key authentication should fail. clientConfig.Auth = []AuthMethod{ PublicKeys(testSigners["rsa"]), } serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig) if err == nil { t.Fatal("client login with public key only must fail") } // The error sequence is: // - no auth passed yet // - partial success if len(serverAuthErrors) != 2 { t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) } if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok { t.Fatal("server not returned partial success") } // Public key and wrong password. clientConfig.Auth = []AuthMethod{ PublicKeys(testSigners["rsa"]), Password("WRONG"), } serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig) if err == nil { t.Fatal("client login with wrong password after public key must fail") } // The error sequence is: // - no auth passed yet // - partial success // - password auth failed if len(serverAuthErrors) != 3 { t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) } if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok { t.Fatal("server not returned partial success") } if serverAuthErrors[2] != errPwdAuthFailed { t.Fatal("server not returned password authentication failed") } // Public key, public key again and then correct password. Public key // authentication is attempted only once because the partial success error // returns only "password" as the allowed authentication method. clientConfig.Auth = []AuthMethod{ PublicKeys(testSigners["rsa"]), PublicKeys(testSigners["rsa"]), Password(clientPassword), } serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig) if err != nil { t.Fatalf("client login error: %s", err) } // The error sequence is: // - no auth passed yet // - partial success // - nil if len(serverAuthErrors) != 3 { t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) } if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok { t.Fatal("server not returned partial success") } // The unrestricted username can do anything clientConfig = &ClientConfig{ User: username, Auth: []AuthMethod{ PublicKeys(testSigners["rsa"]), Password(clientPassword), }, HostKeyCallback: InsecureIgnoreHostKey(), } _, err = doClientServerAuth(t, serverConfig, clientConfig) if err != nil { t.Fatalf("unrestricted client login error: %s", err) } clientConfig = &ClientConfig{ User: username, Auth: []AuthMethod{ PublicKeys(testSigners["rsa"]), }, HostKeyCallback: InsecureIgnoreHostKey(), } _, err = doClientServerAuth(t, serverConfig, clientConfig) if err != nil { t.Fatalf("unrestricted client login error: %s", err) } clientConfig = &ClientConfig{ User: username, Auth: []AuthMethod{ Password(clientPassword), }, HostKeyCallback: InsecureIgnoreHostKey(), } _, err = doClientServerAuth(t, serverConfig, clientConfig) if err != nil { t.Fatalf("unrestricted client login error: %s", err) } } func TestDynamicAuthCallbacks(t *testing.T) { user1 := "user1" user2 := "user2" errInvalidCredentials := errors.New("invalid credentials") serverConfig := &ServerConfig{ NoClientAuth: true, NoClientAuthCallback: func(conn ConnMetadata) (*Permissions, error) { switch conn.User() { case user1: return nil, &PartialSuccessError{ Next: ServerAuthCallbacks{ PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) { if conn.User() == user1 && string(password) == clientPassword { return nil, nil } return nil, errInvalidCredentials }, }, } case user2: return nil, &PartialSuccessError{ Next: ServerAuthCallbacks{ PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) { if bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) { if conn.User() == user2 { return nil, nil } } return nil, errInvalidCredentials }, }, } default: return nil, errInvalidCredentials } }, } clientConfig := &ClientConfig{ User: user1, Auth: []AuthMethod{ Password(clientPassword), }, HostKeyCallback: InsecureIgnoreHostKey(), } serverAuthErrors, err := doClientServerAuth(t, serverConfig, clientConfig) if err != nil { t.Fatalf("client login error: %s", err) } // The error sequence is: // - partial success // - nil if len(serverAuthErrors) != 2 { t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) } if _, ok := serverAuthErrors[0].(*PartialSuccessError); !ok { t.Fatal("server not returned partial success") } clientConfig = &ClientConfig{ User: user2, Auth: []AuthMethod{ PublicKeys(testSigners["rsa"]), }, HostKeyCallback: InsecureIgnoreHostKey(), } serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig) if err != nil { t.Fatalf("client login error: %s", err) } // The error sequence is: // - partial success // - nil if len(serverAuthErrors) != 2 { t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) } if _, ok := serverAuthErrors[0].(*PartialSuccessError); !ok { t.Fatal("server not returned partial success") } // user1 cannot login with public key clientConfig = &ClientConfig{ User: user1, Auth: []AuthMethod{ PublicKeys(testSigners["rsa"]), }, HostKeyCallback: InsecureIgnoreHostKey(), } serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig) if err == nil { t.Fatal("user1 login with public key must fail") } if !strings.Contains(err.Error(), "no supported methods remain") { t.Errorf("got %v, expected 'no supported methods remain'", err) } if len(serverAuthErrors) != 1 { t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) } if _, ok := serverAuthErrors[0].(*PartialSuccessError); !ok { t.Fatal("server not returned partial success") } // user2 cannot login with password clientConfig = &ClientConfig{ User: user2, Auth: []AuthMethod{ Password(clientPassword), }, HostKeyCallback: InsecureIgnoreHostKey(), } serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig) if err == nil { t.Fatal("user2 login with password must fail") } if !strings.Contains(err.Error(), "no supported methods remain") { t.Errorf("got %v, expected 'no supported methods remain'", err) } if len(serverAuthErrors) != 1 { t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) } if _, ok := serverAuthErrors[0].(*PartialSuccessError); !ok { t.Fatal("server not returned partial success") } }