fix: improve login_hint usage on IDPs (#6899)

* only set prompt if no login_hint is set

* update to current state and cleanup

(cherry picked from commit 0386fe7f96)
This commit is contained in:
Livio Spring 2023-11-13 10:25:26 +02:00
parent af24208b38
commit 18788b6045
No known key found for this signature in database
GPG Key ID: 26BB1C2FA5952CF0
8 changed files with 85 additions and 59 deletions

View File

@ -1173,8 +1173,9 @@ func mapExternalNotFoundOptionFormDataToLoginUser(formData *externalNotFoundOpti
} }
} }
func (l *Login) sessionParamsFromAuthRequest(ctx context.Context, authReq *domain.AuthRequest, identityProviderID string) []any { func (l *Login) sessionParamsFromAuthRequest(ctx context.Context, authReq *domain.AuthRequest, identityProviderID string) []idp.Parameter {
params := []any{authReq.AgentID} params := make([]idp.Parameter, 1, 2)
params[0] = idp.UserAgentID(authReq.AgentID)
if authReq.UserID != "" && identityProviderID != "" { if authReq.UserID != "" && identityProviderID != "" {
links, err := l.getUserLinks(ctx, authReq.UserID, identityProviderID) links, err := l.getUserLinks(ctx, authReq.UserID, identityProviderID)
@ -1183,27 +1184,21 @@ func (l *Login) sessionParamsFromAuthRequest(ctx context.Context, authReq *domai
return params return params
} }
if len(links.Links) == 1 { if len(links.Links) == 1 {
return append(params, keyAndValueToAuthURLOpt("login_hint", links.Links[0].ProvidedUsername)) return append(params, idp.LoginHintParam(links.Links[0].ProvidedUsername))
} }
} }
if authReq.UserName != "" { if authReq.UserName != "" {
return append(params, keyAndValueToAuthURLOpt("login_hint", authReq.UserName)) return append(params, idp.LoginHintParam(authReq.UserName))
} }
if authReq.LoginName != "" { if authReq.LoginName != "" {
return append(params, keyAndValueToAuthURLOpt("login_hint", authReq.LoginName)) return append(params, idp.LoginHintParam(authReq.LoginName))
} }
if authReq.LoginHint != "" { if authReq.LoginHint != "" {
return append(params, keyAndValueToAuthURLOpt("login_hint", authReq.LoginHint)) return append(params, idp.LoginHintParam(authReq.LoginHint))
} }
return params return params
} }
func keyAndValueToAuthURLOpt(key, value string) rp.AuthURLOpt {
return func() []oauth2.AuthCodeOption {
return []oauth2.AuthCodeOption{oauth2.SetAuthURLParam(key, value)}
}
}
func (l *Login) getUserLinks(ctx context.Context, userID, idpID string) (*query.IDPUserLinks, error) { func (l *Login) getUserLinks(ctx context.Context, userID, idpID string) (*query.IDPUserLinks, error) {
userIDQuery, err := query.NewIDPUserLinksUserIDSearchQuery(userID) userIDQuery, err := query.NewIDPUserLinksUserIDSearchQuery(userID)
if err != nil { if err != nil {

View File

@ -11,7 +11,7 @@ import (
// Provider is the minimal implementation for a 3rd party authentication provider // Provider is the minimal implementation for a 3rd party authentication provider
type Provider interface { type Provider interface {
Name() string Name() string
BeginAuth(ctx context.Context, state string, params ...any) (Session, error) BeginAuth(ctx context.Context, state string, params ...Parameter) (Session, error)
IsLinkingAllowed() bool IsLinkingAllowed() bool
IsCreationAllowed() bool IsCreationAllowed() bool
IsAutoCreation() bool IsAutoCreation() bool
@ -34,3 +34,18 @@ type User interface {
GetAvatarURL() string GetAvatarURL() string
GetProfile() string GetProfile() string
} }
// Parameter allows to pass specific parameter to the BeginAuth function
type Parameter interface {
setValue()
}
// UserAgentID allows to pass the user agent ID of the auth request to BeginAuth
type UserAgentID string
func (p UserAgentID) setValue() {}
// LoginHintParam allows to pass a login_hint to BeginAuth
type LoginHintParam string
func (p LoginHintParam) setValue() {}

View File

@ -91,13 +91,10 @@ func (p *Provider) Name() string {
// BeginAuth implements the [idp.Provider] interface. // BeginAuth implements the [idp.Provider] interface.
// It will create a [Session] with an AuthURL, pointing to the jwtEndpoint // It will create a [Session] with an AuthURL, pointing to the jwtEndpoint
// with the authRequest and encrypted userAgent ids. // with the authRequest and encrypted userAgent ids.
func (p *Provider) BeginAuth(ctx context.Context, state string, params ...any) (idp.Session, error) { func (p *Provider) BeginAuth(ctx context.Context, state string, params ...idp.Parameter) (idp.Session, error) {
if len(params) < 1 { userAgentID, err := userAgentIDFromParams(params...)
return nil, ErrMissingUserAgentID if err != nil {
} return nil, err
userAgentID, ok := params[0].(string)
if !ok {
return nil, ErrMissingUserAgentID
} }
redirect, err := url.Parse(p.jwtEndpoint) redirect, err := url.Parse(p.jwtEndpoint)
if err != nil { if err != nil {
@ -114,6 +111,15 @@ func (p *Provider) BeginAuth(ctx context.Context, state string, params ...any) (
return &Session{AuthURL: redirect.String()}, nil return &Session{AuthURL: redirect.String()}, nil
} }
func userAgentIDFromParams(params ...idp.Parameter) (string, error) {
for _, param := range params {
if id, ok := param.(idp.UserAgentID); ok {
return string(id), nil
}
}
return "", ErrMissingUserAgentID
}
// IsLinkingAllowed implements the [idp.Provider] interface. // IsLinkingAllowed implements the [idp.Provider] interface.
func (p *Provider) IsLinkingAllowed() bool { func (p *Provider) IsLinkingAllowed() bool {
return p.isLinkingAllowed return p.isLinkingAllowed

View File

@ -23,7 +23,7 @@ func TestProvider_BeginAuth(t *testing.T) {
encryptionAlg func(t *testing.T) crypto.EncryptionAlgorithm encryptionAlg func(t *testing.T) crypto.EncryptionAlgorithm
} }
type args struct { type args struct {
params []any params []idp.Parameter
} }
type want struct { type want struct {
session idp.Session session idp.Session
@ -55,28 +55,6 @@ func TestProvider_BeginAuth(t *testing.T) {
}, },
}, },
}, },
{
name: "invalid userAgentID error",
fields: fields{
issuer: "https://jwt.com",
jwtEndpoint: "https://auth.com/jwt",
keysEndpoint: "https://jwt.com/keys",
headerName: "jwt-header",
encryptionAlg: func(t *testing.T) crypto.EncryptionAlgorithm {
return crypto.CreateMockEncryptionAlg(gomock.NewController(t))
},
},
args: args{
params: []any{
0,
},
},
want: want{
err: func(err error) bool {
return errors.Is(err, ErrMissingUserAgentID)
},
},
},
{ {
name: "successful auth", name: "successful auth",
fields: fields{ fields: fields{
@ -89,8 +67,8 @@ func TestProvider_BeginAuth(t *testing.T) {
}, },
}, },
args: args{ args: args{
params: []any{ params: []idp.Parameter{
"agent", idp.UserAgentID("agent"),
}, },
}, },
want: want{ want: want{

View File

@ -211,7 +211,7 @@ func (p *Provider) Name() string {
return p.name return p.name
} }
func (p *Provider) BeginAuth(ctx context.Context, state string, params ...any) (idp.Session, error) { func (p *Provider) BeginAuth(ctx context.Context, state string, _ ...idp.Parameter) (idp.Session, error) {
return &Session{ return &Session{
Provider: p, Provider: p,
loginUrl: p.loginUrl + state, loginUrl: p.loginUrl + state,

View File

@ -87,17 +87,28 @@ func (p *Provider) Name() string {
// BeginAuth implements the [idp.Provider] interface. // BeginAuth implements the [idp.Provider] interface.
// It will create a [Session] with an OAuth2.0 authorization request as AuthURL. // It will create a [Session] with an OAuth2.0 authorization request as AuthURL.
func (p *Provider) BeginAuth(ctx context.Context, state string, params ...any) (idp.Session, error) { func (p *Provider) BeginAuth(ctx context.Context, state string, params ...idp.Parameter) (idp.Session, error) {
opts := []rp.AuthURLOpt{rp.WithPrompt(oidc.PromptSelectAccount)} opts := make([]rp.AuthURLOpt, 0)
var loginHintSet bool
for _, param := range params { for _, param := range params {
if option, ok := param.(rp.AuthURLOpt); ok { if username, ok := param.(idp.LoginHintParam); ok {
opts = append(opts, option) loginHintSet = true
opts = append(opts, loginHint(string(username)))
} }
} }
if !loginHintSet {
opts = append(opts, rp.WithPrompt(oidc.PromptSelectAccount))
}
url := rp.AuthURL(state, p.RelyingParty, opts...) url := rp.AuthURL(state, p.RelyingParty, opts...)
return &Session{AuthURL: url, Provider: p}, nil return &Session{AuthURL: url, Provider: p}, nil
} }
func loginHint(hint string) rp.AuthURLOpt {
return func() []oauth2.AuthCodeOption {
return []oauth2.AuthCodeOption{oauth2.SetAuthURLParam("login_hint", hint)}
}
}
// IsLinkingAllowed implements the [idp.Provider] interface. // IsLinkingAllowed implements the [idp.Provider] interface.
func (p *Provider) IsLinkingAllowed() bool { func (p *Provider) IsLinkingAllowed() bool {
return p.isLinkingAllowed return p.isLinkingAllowed

View File

@ -5,6 +5,7 @@ import (
"github.com/zitadel/oidc/v3/pkg/client/rp" "github.com/zitadel/oidc/v3/pkg/client/rp"
"github.com/zitadel/oidc/v3/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
"golang.org/x/oauth2"
"github.com/zitadel/zitadel/internal/idp" "github.com/zitadel/zitadel/internal/idp"
) )
@ -22,7 +23,7 @@ type Provider struct {
isAutoUpdate bool isAutoUpdate bool
useIDToken bool useIDToken bool
userInfoMapper func(info *oidc.UserInfo) idp.User userInfoMapper func(info *oidc.UserInfo) idp.User
authOptions []rp.AuthURLOpt authOptions []func(bool) rp.AuthURLOpt
} }
type ProviderOpts func(provider *Provider) type ProviderOpts func(provider *Provider)
@ -70,10 +71,15 @@ func WithRelyingPartyOption(option rp.Option) ProviderOpts {
} }
} }
// WithSelectAccount adds the select_account prompt to the auth request // WithSelectAccount adds the select_account prompt to the auth request (if no login_hint is set)
func WithSelectAccount() ProviderOpts { func WithSelectAccount() ProviderOpts {
return func(p *Provider) { return func(p *Provider) {
p.authOptions = append(p.authOptions, rp.WithPrompt(oidc.PromptSelectAccount)) p.authOptions = append(p.authOptions, func(loginHintSet bool) rp.AuthURLOpt {
if loginHintSet {
return nil
}
return rp.WithPrompt(oidc.PromptSelectAccount)
})
} }
} }
@ -81,7 +87,9 @@ func WithSelectAccount() ProviderOpts {
func WithResponseMode(mode oidc.ResponseMode) ProviderOpts { func WithResponseMode(mode oidc.ResponseMode) ProviderOpts {
return func(p *Provider) { return func(p *Provider) {
paramOpt := rp.WithResponseModeURLParam(mode) paramOpt := rp.WithResponseModeURLParam(mode)
p.authOptions = append(p.authOptions, rp.AuthURLOpt(paramOpt)) p.authOptions = append(p.authOptions, func(_ bool) rp.AuthURLOpt {
return rp.AuthURLOpt(paramOpt)
})
} }
} }
@ -128,17 +136,30 @@ func (p *Provider) Name() string {
// BeginAuth implements the [idp.Provider] interface. // BeginAuth implements the [idp.Provider] interface.
// It will create a [Session] with an OIDC authorization request as AuthURL. // It will create a [Session] with an OIDC authorization request as AuthURL.
func (p *Provider) BeginAuth(ctx context.Context, state string, params ...any) (idp.Session, error) { func (p *Provider) BeginAuth(ctx context.Context, state string, params ...idp.Parameter) (idp.Session, error) {
opts := p.authOptions opts := make([]rp.AuthURLOpt, 0)
var loginHintSet bool
for _, param := range params { for _, param := range params {
if option, ok := param.(rp.AuthURLOpt); ok { if username, ok := param.(idp.LoginHintParam); ok {
opts = append(opts, option) loginHintSet = true
opts = append(opts, loginHint(string(username)))
}
}
for _, option := range p.authOptions {
if opt := option(loginHintSet); opt != nil {
opts = append(opts, opt)
} }
} }
url := rp.AuthURL(state, p.RelyingParty, opts...) url := rp.AuthURL(state, p.RelyingParty, opts...)
return &Session{AuthURL: url, Provider: p}, nil return &Session{AuthURL: url, Provider: p}, nil
} }
func loginHint(hint string) rp.AuthURLOpt {
return func() []oauth2.AuthCodeOption {
return []oauth2.AuthCodeOption{oauth2.SetAuthURLParam("login_hint", hint)}
}
}
// IsLinkingAllowed implements the [idp.Provider] interface. // IsLinkingAllowed implements the [idp.Provider] interface.
func (p *Provider) IsLinkingAllowed() bool { func (p *Provider) IsLinkingAllowed() bool {
return p.isLinkingAllowed return p.isLinkingAllowed

View File

@ -162,7 +162,7 @@ func (p *Provider) GetSP() (*samlsp.Middleware, error) {
return sp, nil return sp, nil
} }
func (p *Provider) BeginAuth(ctx context.Context, state string, params ...any) (idp.Session, error) { func (p *Provider) BeginAuth(ctx context.Context, state string, _ ...idp.Parameter) (idp.Session, error) {
m, err := p.GetSP() m, err := p.GetSP()
if err != nil { if err != nil {
return nil, err return nil, err