mirror of
https://github.com/zitadel/zitadel.git
synced 2024-12-14 03:54:21 +00:00
fix: improve login_hint usage on IDPs (#6899)
* only set prompt if no login_hint is set
* update to current state and cleanup
(cherry picked from commit 0386fe7f96
)
This commit is contained in:
parent
af24208b38
commit
18788b6045
@ -1173,8 +1173,9 @@ func mapExternalNotFoundOptionFormDataToLoginUser(formData *externalNotFoundOpti
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Login) sessionParamsFromAuthRequest(ctx context.Context, authReq *domain.AuthRequest, identityProviderID string) []any {
|
||||
params := []any{authReq.AgentID}
|
||||
func (l *Login) sessionParamsFromAuthRequest(ctx context.Context, authReq *domain.AuthRequest, identityProviderID string) []idp.Parameter {
|
||||
params := make([]idp.Parameter, 1, 2)
|
||||
params[0] = idp.UserAgentID(authReq.AgentID)
|
||||
|
||||
if authReq.UserID != "" && identityProviderID != "" {
|
||||
links, err := l.getUserLinks(ctx, authReq.UserID, identityProviderID)
|
||||
@ -1183,27 +1184,21 @@ func (l *Login) sessionParamsFromAuthRequest(ctx context.Context, authReq *domai
|
||||
return params
|
||||
}
|
||||
if len(links.Links) == 1 {
|
||||
return append(params, keyAndValueToAuthURLOpt("login_hint", links.Links[0].ProvidedUsername))
|
||||
return append(params, idp.LoginHintParam(links.Links[0].ProvidedUsername))
|
||||
}
|
||||
}
|
||||
if authReq.UserName != "" {
|
||||
return append(params, keyAndValueToAuthURLOpt("login_hint", authReq.UserName))
|
||||
return append(params, idp.LoginHintParam(authReq.UserName))
|
||||
}
|
||||
if authReq.LoginName != "" {
|
||||
return append(params, keyAndValueToAuthURLOpt("login_hint", authReq.LoginName))
|
||||
return append(params, idp.LoginHintParam(authReq.LoginName))
|
||||
}
|
||||
if authReq.LoginHint != "" {
|
||||
return append(params, keyAndValueToAuthURLOpt("login_hint", authReq.LoginHint))
|
||||
return append(params, idp.LoginHintParam(authReq.LoginHint))
|
||||
}
|
||||
return params
|
||||
}
|
||||
|
||||
func keyAndValueToAuthURLOpt(key, value string) rp.AuthURLOpt {
|
||||
return func() []oauth2.AuthCodeOption {
|
||||
return []oauth2.AuthCodeOption{oauth2.SetAuthURLParam(key, value)}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Login) getUserLinks(ctx context.Context, userID, idpID string) (*query.IDPUserLinks, error) {
|
||||
userIDQuery, err := query.NewIDPUserLinksUserIDSearchQuery(userID)
|
||||
if err != nil {
|
||||
|
@ -11,7 +11,7 @@ import (
|
||||
// Provider is the minimal implementation for a 3rd party authentication provider
|
||||
type Provider interface {
|
||||
Name() string
|
||||
BeginAuth(ctx context.Context, state string, params ...any) (Session, error)
|
||||
BeginAuth(ctx context.Context, state string, params ...Parameter) (Session, error)
|
||||
IsLinkingAllowed() bool
|
||||
IsCreationAllowed() bool
|
||||
IsAutoCreation() bool
|
||||
@ -34,3 +34,18 @@ type User interface {
|
||||
GetAvatarURL() string
|
||||
GetProfile() string
|
||||
}
|
||||
|
||||
// Parameter allows to pass specific parameter to the BeginAuth function
|
||||
type Parameter interface {
|
||||
setValue()
|
||||
}
|
||||
|
||||
// UserAgentID allows to pass the user agent ID of the auth request to BeginAuth
|
||||
type UserAgentID string
|
||||
|
||||
func (p UserAgentID) setValue() {}
|
||||
|
||||
// LoginHintParam allows to pass a login_hint to BeginAuth
|
||||
type LoginHintParam string
|
||||
|
||||
func (p LoginHintParam) setValue() {}
|
||||
|
@ -91,13 +91,10 @@ func (p *Provider) Name() string {
|
||||
// BeginAuth implements the [idp.Provider] interface.
|
||||
// It will create a [Session] with an AuthURL, pointing to the jwtEndpoint
|
||||
// with the authRequest and encrypted userAgent ids.
|
||||
func (p *Provider) BeginAuth(ctx context.Context, state string, params ...any) (idp.Session, error) {
|
||||
if len(params) < 1 {
|
||||
return nil, ErrMissingUserAgentID
|
||||
}
|
||||
userAgentID, ok := params[0].(string)
|
||||
if !ok {
|
||||
return nil, ErrMissingUserAgentID
|
||||
func (p *Provider) BeginAuth(ctx context.Context, state string, params ...idp.Parameter) (idp.Session, error) {
|
||||
userAgentID, err := userAgentIDFromParams(params...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
redirect, err := url.Parse(p.jwtEndpoint)
|
||||
if err != nil {
|
||||
@ -114,6 +111,15 @@ func (p *Provider) BeginAuth(ctx context.Context, state string, params ...any) (
|
||||
return &Session{AuthURL: redirect.String()}, nil
|
||||
}
|
||||
|
||||
func userAgentIDFromParams(params ...idp.Parameter) (string, error) {
|
||||
for _, param := range params {
|
||||
if id, ok := param.(idp.UserAgentID); ok {
|
||||
return string(id), nil
|
||||
}
|
||||
}
|
||||
return "", ErrMissingUserAgentID
|
||||
}
|
||||
|
||||
// IsLinkingAllowed implements the [idp.Provider] interface.
|
||||
func (p *Provider) IsLinkingAllowed() bool {
|
||||
return p.isLinkingAllowed
|
||||
|
@ -23,7 +23,7 @@ func TestProvider_BeginAuth(t *testing.T) {
|
||||
encryptionAlg func(t *testing.T) crypto.EncryptionAlgorithm
|
||||
}
|
||||
type args struct {
|
||||
params []any
|
||||
params []idp.Parameter
|
||||
}
|
||||
type want struct {
|
||||
session idp.Session
|
||||
@ -55,28 +55,6 @@ func TestProvider_BeginAuth(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid userAgentID error",
|
||||
fields: fields{
|
||||
issuer: "https://jwt.com",
|
||||
jwtEndpoint: "https://auth.com/jwt",
|
||||
keysEndpoint: "https://jwt.com/keys",
|
||||
headerName: "jwt-header",
|
||||
encryptionAlg: func(t *testing.T) crypto.EncryptionAlgorithm {
|
||||
return crypto.CreateMockEncryptionAlg(gomock.NewController(t))
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
params: []any{
|
||||
0,
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
err: func(err error) bool {
|
||||
return errors.Is(err, ErrMissingUserAgentID)
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "successful auth",
|
||||
fields: fields{
|
||||
@ -89,8 +67,8 @@ func TestProvider_BeginAuth(t *testing.T) {
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
params: []any{
|
||||
"agent",
|
||||
params: []idp.Parameter{
|
||||
idp.UserAgentID("agent"),
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
|
@ -211,7 +211,7 @@ func (p *Provider) Name() string {
|
||||
return p.name
|
||||
}
|
||||
|
||||
func (p *Provider) BeginAuth(ctx context.Context, state string, params ...any) (idp.Session, error) {
|
||||
func (p *Provider) BeginAuth(ctx context.Context, state string, _ ...idp.Parameter) (idp.Session, error) {
|
||||
return &Session{
|
||||
Provider: p,
|
||||
loginUrl: p.loginUrl + state,
|
||||
|
@ -87,17 +87,28 @@ func (p *Provider) Name() string {
|
||||
|
||||
// BeginAuth implements the [idp.Provider] interface.
|
||||
// It will create a [Session] with an OAuth2.0 authorization request as AuthURL.
|
||||
func (p *Provider) BeginAuth(ctx context.Context, state string, params ...any) (idp.Session, error) {
|
||||
opts := []rp.AuthURLOpt{rp.WithPrompt(oidc.PromptSelectAccount)}
|
||||
func (p *Provider) BeginAuth(ctx context.Context, state string, params ...idp.Parameter) (idp.Session, error) {
|
||||
opts := make([]rp.AuthURLOpt, 0)
|
||||
var loginHintSet bool
|
||||
for _, param := range params {
|
||||
if option, ok := param.(rp.AuthURLOpt); ok {
|
||||
opts = append(opts, option)
|
||||
if username, ok := param.(idp.LoginHintParam); ok {
|
||||
loginHintSet = true
|
||||
opts = append(opts, loginHint(string(username)))
|
||||
}
|
||||
}
|
||||
if !loginHintSet {
|
||||
opts = append(opts, rp.WithPrompt(oidc.PromptSelectAccount))
|
||||
}
|
||||
url := rp.AuthURL(state, p.RelyingParty, opts...)
|
||||
return &Session{AuthURL: url, Provider: p}, nil
|
||||
}
|
||||
|
||||
func loginHint(hint string) rp.AuthURLOpt {
|
||||
return func() []oauth2.AuthCodeOption {
|
||||
return []oauth2.AuthCodeOption{oauth2.SetAuthURLParam("login_hint", hint)}
|
||||
}
|
||||
}
|
||||
|
||||
// IsLinkingAllowed implements the [idp.Provider] interface.
|
||||
func (p *Provider) IsLinkingAllowed() bool {
|
||||
return p.isLinkingAllowed
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
|
||||
"github.com/zitadel/oidc/v3/pkg/client/rp"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/idp"
|
||||
)
|
||||
@ -22,7 +23,7 @@ type Provider struct {
|
||||
isAutoUpdate bool
|
||||
useIDToken bool
|
||||
userInfoMapper func(info *oidc.UserInfo) idp.User
|
||||
authOptions []rp.AuthURLOpt
|
||||
authOptions []func(bool) rp.AuthURLOpt
|
||||
}
|
||||
|
||||
type ProviderOpts func(provider *Provider)
|
||||
@ -70,10 +71,15 @@ func WithRelyingPartyOption(option rp.Option) ProviderOpts {
|
||||
}
|
||||
}
|
||||
|
||||
// WithSelectAccount adds the select_account prompt to the auth request
|
||||
// WithSelectAccount adds the select_account prompt to the auth request (if no login_hint is set)
|
||||
func WithSelectAccount() ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
p.authOptions = append(p.authOptions, rp.WithPrompt(oidc.PromptSelectAccount))
|
||||
p.authOptions = append(p.authOptions, func(loginHintSet bool) rp.AuthURLOpt {
|
||||
if loginHintSet {
|
||||
return nil
|
||||
}
|
||||
return rp.WithPrompt(oidc.PromptSelectAccount)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -81,7 +87,9 @@ func WithSelectAccount() ProviderOpts {
|
||||
func WithResponseMode(mode oidc.ResponseMode) ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
paramOpt := rp.WithResponseModeURLParam(mode)
|
||||
p.authOptions = append(p.authOptions, rp.AuthURLOpt(paramOpt))
|
||||
p.authOptions = append(p.authOptions, func(_ bool) rp.AuthURLOpt {
|
||||
return rp.AuthURLOpt(paramOpt)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -128,17 +136,30 @@ func (p *Provider) Name() string {
|
||||
|
||||
// BeginAuth implements the [idp.Provider] interface.
|
||||
// It will create a [Session] with an OIDC authorization request as AuthURL.
|
||||
func (p *Provider) BeginAuth(ctx context.Context, state string, params ...any) (idp.Session, error) {
|
||||
opts := p.authOptions
|
||||
func (p *Provider) BeginAuth(ctx context.Context, state string, params ...idp.Parameter) (idp.Session, error) {
|
||||
opts := make([]rp.AuthURLOpt, 0)
|
||||
var loginHintSet bool
|
||||
for _, param := range params {
|
||||
if option, ok := param.(rp.AuthURLOpt); ok {
|
||||
opts = append(opts, option)
|
||||
if username, ok := param.(idp.LoginHintParam); ok {
|
||||
loginHintSet = true
|
||||
opts = append(opts, loginHint(string(username)))
|
||||
}
|
||||
}
|
||||
for _, option := range p.authOptions {
|
||||
if opt := option(loginHintSet); opt != nil {
|
||||
opts = append(opts, opt)
|
||||
}
|
||||
}
|
||||
url := rp.AuthURL(state, p.RelyingParty, opts...)
|
||||
return &Session{AuthURL: url, Provider: p}, nil
|
||||
}
|
||||
|
||||
func loginHint(hint string) rp.AuthURLOpt {
|
||||
return func() []oauth2.AuthCodeOption {
|
||||
return []oauth2.AuthCodeOption{oauth2.SetAuthURLParam("login_hint", hint)}
|
||||
}
|
||||
}
|
||||
|
||||
// IsLinkingAllowed implements the [idp.Provider] interface.
|
||||
func (p *Provider) IsLinkingAllowed() bool {
|
||||
return p.isLinkingAllowed
|
||||
|
@ -162,7 +162,7 @@ func (p *Provider) GetSP() (*samlsp.Middleware, error) {
|
||||
return sp, nil
|
||||
}
|
||||
|
||||
func (p *Provider) BeginAuth(ctx context.Context, state string, params ...any) (idp.Session, error) {
|
||||
func (p *Provider) BeginAuth(ctx context.Context, state string, _ ...idp.Parameter) (idp.Session, error) {
|
||||
m, err := p.GetSP()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
Loading…
Reference in New Issue
Block a user