fix: improve login_hint usage on IDPs (#6899)

* only set prompt if no login_hint is set

* update to current state and cleanup

(cherry picked from commit 0386fe7f96)
This commit is contained in:
Livio Spring 2023-11-13 10:25:26 +02:00
parent af24208b38
commit 18788b6045
No known key found for this signature in database
GPG Key ID: 26BB1C2FA5952CF0
8 changed files with 85 additions and 59 deletions

View File

@ -1173,8 +1173,9 @@ func mapExternalNotFoundOptionFormDataToLoginUser(formData *externalNotFoundOpti
}
}
func (l *Login) sessionParamsFromAuthRequest(ctx context.Context, authReq *domain.AuthRequest, identityProviderID string) []any {
params := []any{authReq.AgentID}
func (l *Login) sessionParamsFromAuthRequest(ctx context.Context, authReq *domain.AuthRequest, identityProviderID string) []idp.Parameter {
params := make([]idp.Parameter, 1, 2)
params[0] = idp.UserAgentID(authReq.AgentID)
if authReq.UserID != "" && identityProviderID != "" {
links, err := l.getUserLinks(ctx, authReq.UserID, identityProviderID)
@ -1183,27 +1184,21 @@ func (l *Login) sessionParamsFromAuthRequest(ctx context.Context, authReq *domai
return params
}
if len(links.Links) == 1 {
return append(params, keyAndValueToAuthURLOpt("login_hint", links.Links[0].ProvidedUsername))
return append(params, idp.LoginHintParam(links.Links[0].ProvidedUsername))
}
}
if authReq.UserName != "" {
return append(params, keyAndValueToAuthURLOpt("login_hint", authReq.UserName))
return append(params, idp.LoginHintParam(authReq.UserName))
}
if authReq.LoginName != "" {
return append(params, keyAndValueToAuthURLOpt("login_hint", authReq.LoginName))
return append(params, idp.LoginHintParam(authReq.LoginName))
}
if authReq.LoginHint != "" {
return append(params, keyAndValueToAuthURLOpt("login_hint", authReq.LoginHint))
return append(params, idp.LoginHintParam(authReq.LoginHint))
}
return params
}
func keyAndValueToAuthURLOpt(key, value string) rp.AuthURLOpt {
return func() []oauth2.AuthCodeOption {
return []oauth2.AuthCodeOption{oauth2.SetAuthURLParam(key, value)}
}
}
func (l *Login) getUserLinks(ctx context.Context, userID, idpID string) (*query.IDPUserLinks, error) {
userIDQuery, err := query.NewIDPUserLinksUserIDSearchQuery(userID)
if err != nil {

View File

@ -11,7 +11,7 @@ import (
// Provider is the minimal implementation for a 3rd party authentication provider
type Provider interface {
Name() string
BeginAuth(ctx context.Context, state string, params ...any) (Session, error)
BeginAuth(ctx context.Context, state string, params ...Parameter) (Session, error)
IsLinkingAllowed() bool
IsCreationAllowed() bool
IsAutoCreation() bool
@ -34,3 +34,18 @@ type User interface {
GetAvatarURL() string
GetProfile() string
}
// Parameter allows to pass specific parameter to the BeginAuth function
type Parameter interface {
setValue()
}
// UserAgentID allows to pass the user agent ID of the auth request to BeginAuth
type UserAgentID string
func (p UserAgentID) setValue() {}
// LoginHintParam allows to pass a login_hint to BeginAuth
type LoginHintParam string
func (p LoginHintParam) setValue() {}

View File

@ -91,13 +91,10 @@ func (p *Provider) Name() string {
// BeginAuth implements the [idp.Provider] interface.
// It will create a [Session] with an AuthURL, pointing to the jwtEndpoint
// with the authRequest and encrypted userAgent ids.
func (p *Provider) BeginAuth(ctx context.Context, state string, params ...any) (idp.Session, error) {
if len(params) < 1 {
return nil, ErrMissingUserAgentID
}
userAgentID, ok := params[0].(string)
if !ok {
return nil, ErrMissingUserAgentID
func (p *Provider) BeginAuth(ctx context.Context, state string, params ...idp.Parameter) (idp.Session, error) {
userAgentID, err := userAgentIDFromParams(params...)
if err != nil {
return nil, err
}
redirect, err := url.Parse(p.jwtEndpoint)
if err != nil {
@ -114,6 +111,15 @@ func (p *Provider) BeginAuth(ctx context.Context, state string, params ...any) (
return &Session{AuthURL: redirect.String()}, nil
}
func userAgentIDFromParams(params ...idp.Parameter) (string, error) {
for _, param := range params {
if id, ok := param.(idp.UserAgentID); ok {
return string(id), nil
}
}
return "", ErrMissingUserAgentID
}
// IsLinkingAllowed implements the [idp.Provider] interface.
func (p *Provider) IsLinkingAllowed() bool {
return p.isLinkingAllowed

View File

@ -23,7 +23,7 @@ func TestProvider_BeginAuth(t *testing.T) {
encryptionAlg func(t *testing.T) crypto.EncryptionAlgorithm
}
type args struct {
params []any
params []idp.Parameter
}
type want struct {
session idp.Session
@ -55,28 +55,6 @@ func TestProvider_BeginAuth(t *testing.T) {
},
},
},
{
name: "invalid userAgentID error",
fields: fields{
issuer: "https://jwt.com",
jwtEndpoint: "https://auth.com/jwt",
keysEndpoint: "https://jwt.com/keys",
headerName: "jwt-header",
encryptionAlg: func(t *testing.T) crypto.EncryptionAlgorithm {
return crypto.CreateMockEncryptionAlg(gomock.NewController(t))
},
},
args: args{
params: []any{
0,
},
},
want: want{
err: func(err error) bool {
return errors.Is(err, ErrMissingUserAgentID)
},
},
},
{
name: "successful auth",
fields: fields{
@ -89,8 +67,8 @@ func TestProvider_BeginAuth(t *testing.T) {
},
},
args: args{
params: []any{
"agent",
params: []idp.Parameter{
idp.UserAgentID("agent"),
},
},
want: want{

View File

@ -211,7 +211,7 @@ func (p *Provider) Name() string {
return p.name
}
func (p *Provider) BeginAuth(ctx context.Context, state string, params ...any) (idp.Session, error) {
func (p *Provider) BeginAuth(ctx context.Context, state string, _ ...idp.Parameter) (idp.Session, error) {
return &Session{
Provider: p,
loginUrl: p.loginUrl + state,

View File

@ -87,17 +87,28 @@ func (p *Provider) Name() string {
// BeginAuth implements the [idp.Provider] interface.
// It will create a [Session] with an OAuth2.0 authorization request as AuthURL.
func (p *Provider) BeginAuth(ctx context.Context, state string, params ...any) (idp.Session, error) {
opts := []rp.AuthURLOpt{rp.WithPrompt(oidc.PromptSelectAccount)}
func (p *Provider) BeginAuth(ctx context.Context, state string, params ...idp.Parameter) (idp.Session, error) {
opts := make([]rp.AuthURLOpt, 0)
var loginHintSet bool
for _, param := range params {
if option, ok := param.(rp.AuthURLOpt); ok {
opts = append(opts, option)
if username, ok := param.(idp.LoginHintParam); ok {
loginHintSet = true
opts = append(opts, loginHint(string(username)))
}
}
if !loginHintSet {
opts = append(opts, rp.WithPrompt(oidc.PromptSelectAccount))
}
url := rp.AuthURL(state, p.RelyingParty, opts...)
return &Session{AuthURL: url, Provider: p}, nil
}
func loginHint(hint string) rp.AuthURLOpt {
return func() []oauth2.AuthCodeOption {
return []oauth2.AuthCodeOption{oauth2.SetAuthURLParam("login_hint", hint)}
}
}
// IsLinkingAllowed implements the [idp.Provider] interface.
func (p *Provider) IsLinkingAllowed() bool {
return p.isLinkingAllowed

View File

@ -5,6 +5,7 @@ import (
"github.com/zitadel/oidc/v3/pkg/client/rp"
"github.com/zitadel/oidc/v3/pkg/oidc"
"golang.org/x/oauth2"
"github.com/zitadel/zitadel/internal/idp"
)
@ -22,7 +23,7 @@ type Provider struct {
isAutoUpdate bool
useIDToken bool
userInfoMapper func(info *oidc.UserInfo) idp.User
authOptions []rp.AuthURLOpt
authOptions []func(bool) rp.AuthURLOpt
}
type ProviderOpts func(provider *Provider)
@ -70,10 +71,15 @@ func WithRelyingPartyOption(option rp.Option) ProviderOpts {
}
}
// WithSelectAccount adds the select_account prompt to the auth request
// WithSelectAccount adds the select_account prompt to the auth request (if no login_hint is set)
func WithSelectAccount() ProviderOpts {
return func(p *Provider) {
p.authOptions = append(p.authOptions, rp.WithPrompt(oidc.PromptSelectAccount))
p.authOptions = append(p.authOptions, func(loginHintSet bool) rp.AuthURLOpt {
if loginHintSet {
return nil
}
return rp.WithPrompt(oidc.PromptSelectAccount)
})
}
}
@ -81,7 +87,9 @@ func WithSelectAccount() ProviderOpts {
func WithResponseMode(mode oidc.ResponseMode) ProviderOpts {
return func(p *Provider) {
paramOpt := rp.WithResponseModeURLParam(mode)
p.authOptions = append(p.authOptions, rp.AuthURLOpt(paramOpt))
p.authOptions = append(p.authOptions, func(_ bool) rp.AuthURLOpt {
return rp.AuthURLOpt(paramOpt)
})
}
}
@ -128,17 +136,30 @@ func (p *Provider) Name() string {
// BeginAuth implements the [idp.Provider] interface.
// It will create a [Session] with an OIDC authorization request as AuthURL.
func (p *Provider) BeginAuth(ctx context.Context, state string, params ...any) (idp.Session, error) {
opts := p.authOptions
func (p *Provider) BeginAuth(ctx context.Context, state string, params ...idp.Parameter) (idp.Session, error) {
opts := make([]rp.AuthURLOpt, 0)
var loginHintSet bool
for _, param := range params {
if option, ok := param.(rp.AuthURLOpt); ok {
opts = append(opts, option)
if username, ok := param.(idp.LoginHintParam); ok {
loginHintSet = true
opts = append(opts, loginHint(string(username)))
}
}
for _, option := range p.authOptions {
if opt := option(loginHintSet); opt != nil {
opts = append(opts, opt)
}
}
url := rp.AuthURL(state, p.RelyingParty, opts...)
return &Session{AuthURL: url, Provider: p}, nil
}
func loginHint(hint string) rp.AuthURLOpt {
return func() []oauth2.AuthCodeOption {
return []oauth2.AuthCodeOption{oauth2.SetAuthURLParam("login_hint", hint)}
}
}
// IsLinkingAllowed implements the [idp.Provider] interface.
func (p *Provider) IsLinkingAllowed() bool {
return p.isLinkingAllowed

View File

@ -162,7 +162,7 @@ func (p *Provider) GetSP() (*samlsp.Middleware, error) {
return sp, nil
}
func (p *Provider) BeginAuth(ctx context.Context, state string, params ...any) (idp.Session, error) {
func (p *Provider) BeginAuth(ctx context.Context, state string, _ ...idp.Parameter) (idp.Session, error) {
m, err := p.GetSP()
if err != nil {
return nil, err