From 966df560264fb4779d32817a583993134fd23feb Mon Sep 17 00:00:00 2001 From: Livio Spring Date: Wed, 1 Mar 2023 08:17:51 +0100 Subject: [PATCH] fix(providers): set prompt select_account again (#5329) --- internal/api/ui/login/external_provider_handler.go | 4 ++-- internal/api/ui/login/jwt_handler.go | 2 +- internal/idp/providers/azuread/azuread_test.go | 4 ++-- internal/idp/providers/github/github_test.go | 2 +- internal/idp/providers/gitlab/gitlab_test.go | 2 +- internal/idp/providers/google/google_test.go | 2 +- internal/idp/providers/oauth/oauth2.go | 3 ++- internal/idp/providers/oauth/oauth2_test.go | 2 +- internal/idp/providers/oidc/oidc.go | 2 +- internal/idp/providers/oidc/oidc_test.go | 2 +- 10 files changed, 13 insertions(+), 12 deletions(-) diff --git a/internal/api/ui/login/external_provider_handler.go b/internal/api/ui/login/external_provider_handler.go index 9e6d7d3787..19d572ae88 100644 --- a/internal/api/ui/login/external_provider_handler.go +++ b/internal/api/ui/login/external_provider_handler.go @@ -137,7 +137,7 @@ func (l *Login) handleIDP(w http.ResponseWriter, r *http.Request, authReq *domai case domain.IDPTypeOIDC: provider, err = l.oidcProvider(r.Context(), identityProvider) case domain.IDPTypeJWT: - provider, err = l.jwtProvider(r.Context(), identityProvider) + provider, err = l.jwtProvider(identityProvider) case domain.IDPTypeGoogle: provider, err = l.googleProvider(r.Context(), identityProvider) case domain.IDPTypeOAuth, @@ -589,7 +589,7 @@ func (l *Login) oidcProvider(ctx context.Context, identityProvider *query.IDPTem ) } -func (l *Login) jwtProvider(ctx context.Context, identityProvider *query.IDPTemplate) (*jwt.Provider, error) { +func (l *Login) jwtProvider(identityProvider *query.IDPTemplate) (*jwt.Provider, error) { return jwt.New( identityProvider.Name, identityProvider.JWTIDPTemplate.Issuer, diff --git a/internal/api/ui/login/jwt_handler.go b/internal/api/ui/login/jwt_handler.go index 94c5a39e01..8cef3ed747 100644 --- a/internal/api/ui/login/jwt_handler.go +++ b/internal/api/ui/login/jwt_handler.go @@ -74,7 +74,7 @@ func (l *Login) handleJWTExtraction(w http.ResponseWriter, r *http.Request, auth l.renderError(w, r, authReq, err) return } - provider, err := l.jwtProvider(r.Context(), identityProvider) + provider, err := l.jwtProvider(identityProvider) if err != nil { emptyTokens := &oidc.Tokens{Token: &oauth2.Token{}} if _, actionErr := l.runPostExternalAuthenticationActions(&domain.ExternalUser{}, emptyTokens, authReq, r, err); actionErr != nil { diff --git a/internal/idp/providers/azuread/azuread_test.go b/internal/idp/providers/azuread/azuread_test.go index dbd18a9ae9..f257d9dc4a 100644 --- a/internal/idp/providers/azuread/azuread_test.go +++ b/internal/idp/providers/azuread/azuread_test.go @@ -34,7 +34,7 @@ func TestProvider_BeginAuth(t *testing.T) { redirectURI: "redirectURI", }, want: &oidc.Session{ - AuthURL: "https://login.microsoftonline.com/common/oauth2/v2.0/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid+profile+email&state=testState", + AuthURL: "https://login.microsoftonline.com/common/oauth2/v2.0/authorize?client_id=clientID&prompt=select_account&redirect_uri=redirectURI&response_type=code&scope=openid+profile+email&state=testState", }, }, { @@ -48,7 +48,7 @@ func TestProvider_BeginAuth(t *testing.T) { }, }, want: &oidc.Session{ - AuthURL: "https://login.microsoftonline.com/consumers/oauth2/v2.0/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid+profile+email&state=testState", + AuthURL: "https://login.microsoftonline.com/consumers/oauth2/v2.0/authorize?client_id=clientID&prompt=select_account&redirect_uri=redirectURI&response_type=code&scope=openid+profile+email&state=testState", }, }, } diff --git a/internal/idp/providers/github/github_test.go b/internal/idp/providers/github/github_test.go index bc61bb5120..63244f68d5 100644 --- a/internal/idp/providers/github/github_test.go +++ b/internal/idp/providers/github/github_test.go @@ -32,7 +32,7 @@ func TestProvider_BeginAuth(t *testing.T) { redirectURI: "redirectURI", }, want: &oauth.Session{ - AuthURL: "https://github.com/login/oauth/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&state=testState", + AuthURL: "https://github.com/login/oauth/authorize?client_id=clientID&prompt=select_account&redirect_uri=redirectURI&response_type=code&state=testState", }, }, } diff --git a/internal/idp/providers/gitlab/gitlab_test.go b/internal/idp/providers/gitlab/gitlab_test.go index 436a00b2ca..c6660a3e99 100644 --- a/internal/idp/providers/gitlab/gitlab_test.go +++ b/internal/idp/providers/gitlab/gitlab_test.go @@ -33,7 +33,7 @@ func TestProvider_BeginAuth(t *testing.T) { scopes: []string{"openid"}, }, want: &oidc.Session{ - AuthURL: "https://gitlab.com/oauth/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState", + AuthURL: "https://gitlab.com/oauth/authorize?client_id=clientID&prompt=select_account&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState", }, }, } diff --git a/internal/idp/providers/google/google_test.go b/internal/idp/providers/google/google_test.go index ff0a6d5d49..d7fb3ecb81 100644 --- a/internal/idp/providers/google/google_test.go +++ b/internal/idp/providers/google/google_test.go @@ -32,7 +32,7 @@ func TestProvider_BeginAuth(t *testing.T) { scopes: []string{"openid"}, }, want: &oidc.Session{ - AuthURL: "https://accounts.google.com/o/oauth2/v2/auth?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState", + AuthURL: "https://accounts.google.com/o/oauth2/v2/auth?client_id=clientID&prompt=select_account&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState", }, }, } diff --git a/internal/idp/providers/oauth/oauth2.go b/internal/idp/providers/oauth/oauth2.go index 4799131922..774a4dcec5 100644 --- a/internal/idp/providers/oauth/oauth2.go +++ b/internal/idp/providers/oauth/oauth2.go @@ -4,6 +4,7 @@ import ( "context" "github.com/zitadel/oidc/v2/pkg/client/rp" + "github.com/zitadel/oidc/v2/pkg/oidc" "golang.org/x/oauth2" "github.com/zitadel/zitadel/internal/idp" @@ -87,7 +88,7 @@ func (p *Provider) Name() string { // BeginAuth implements the [idp.Provider] interface. // It will create a [Session] with an OAuth2.0 authorization request as AuthURL. func (p *Provider) BeginAuth(ctx context.Context, state string, _ ...any) (idp.Session, error) { - url := rp.AuthURL(state, p.RelyingParty) + url := rp.AuthURL(state, p.RelyingParty, rp.WithPrompt(oidc.PromptSelectAccount)) return &Session{AuthURL: url, Provider: p}, nil } diff --git a/internal/idp/providers/oauth/oauth2_test.go b/internal/idp/providers/oauth/oauth2_test.go index 84eca5509f..d145745918 100644 --- a/internal/idp/providers/oauth/oauth2_test.go +++ b/internal/idp/providers/oauth/oauth2_test.go @@ -38,7 +38,7 @@ func TestProvider_BeginAuth(t *testing.T) { Scopes: []string{"user"}, }, }, - want: &Session{AuthURL: "https://oauth2.com/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=user&state=testState"}, + want: &Session{AuthURL: "https://oauth2.com/authorize?client_id=clientID&prompt=select_account&redirect_uri=redirectURI&response_type=code&scope=user&state=testState"}, }, } for _, tt := range tests { diff --git a/internal/idp/providers/oidc/oidc.go b/internal/idp/providers/oidc/oidc.go index 890b168cb5..81d77d5a32 100644 --- a/internal/idp/providers/oidc/oidc.go +++ b/internal/idp/providers/oidc/oidc.go @@ -105,7 +105,7 @@ func (p *Provider) Name() string { // BeginAuth implements the [idp.Provider] interface. // It will create a [Session] with an OIDC authorization request as AuthURL. func (p *Provider) BeginAuth(ctx context.Context, state string, _ ...any) (idp.Session, error) { - url := rp.AuthURL(state, p.RelyingParty) + url := rp.AuthURL(state, p.RelyingParty, rp.WithPrompt(oidc.PromptSelectAccount)) return &Session{AuthURL: url, Provider: p}, nil } diff --git a/internal/idp/providers/oidc/oidc_test.go b/internal/idp/providers/oidc/oidc_test.go index ca205d7e5b..7875f241b0 100644 --- a/internal/idp/providers/oidc/oidc_test.go +++ b/internal/idp/providers/oidc/oidc_test.go @@ -51,7 +51,7 @@ func TestProvider_BeginAuth(t *testing.T) { }) }, }, - want: &Session{AuthURL: "https://issuer.com/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState"}, + want: &Session{AuthURL: "https://issuer.com/authorize?client_id=clientID&prompt=select_account&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState"}, }, } for _, tt := range tests {