mirror of
https://github.com/zitadel/zitadel.git
synced 2024-12-15 04:18:01 +00:00
fix: improve login_hint usage on IDPs (#6899)
* only set prompt if no login_hint is set
* update to current state and cleanup
(cherry picked from commit 0386fe7f96
)
This commit is contained in:
parent
af24208b38
commit
18788b6045
@ -1173,8 +1173,9 @@ func mapExternalNotFoundOptionFormDataToLoginUser(formData *externalNotFoundOpti
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Login) sessionParamsFromAuthRequest(ctx context.Context, authReq *domain.AuthRequest, identityProviderID string) []any {
|
func (l *Login) sessionParamsFromAuthRequest(ctx context.Context, authReq *domain.AuthRequest, identityProviderID string) []idp.Parameter {
|
||||||
params := []any{authReq.AgentID}
|
params := make([]idp.Parameter, 1, 2)
|
||||||
|
params[0] = idp.UserAgentID(authReq.AgentID)
|
||||||
|
|
||||||
if authReq.UserID != "" && identityProviderID != "" {
|
if authReq.UserID != "" && identityProviderID != "" {
|
||||||
links, err := l.getUserLinks(ctx, authReq.UserID, identityProviderID)
|
links, err := l.getUserLinks(ctx, authReq.UserID, identityProviderID)
|
||||||
@ -1183,27 +1184,21 @@ func (l *Login) sessionParamsFromAuthRequest(ctx context.Context, authReq *domai
|
|||||||
return params
|
return params
|
||||||
}
|
}
|
||||||
if len(links.Links) == 1 {
|
if len(links.Links) == 1 {
|
||||||
return append(params, keyAndValueToAuthURLOpt("login_hint", links.Links[0].ProvidedUsername))
|
return append(params, idp.LoginHintParam(links.Links[0].ProvidedUsername))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if authReq.UserName != "" {
|
if authReq.UserName != "" {
|
||||||
return append(params, keyAndValueToAuthURLOpt("login_hint", authReq.UserName))
|
return append(params, idp.LoginHintParam(authReq.UserName))
|
||||||
}
|
}
|
||||||
if authReq.LoginName != "" {
|
if authReq.LoginName != "" {
|
||||||
return append(params, keyAndValueToAuthURLOpt("login_hint", authReq.LoginName))
|
return append(params, idp.LoginHintParam(authReq.LoginName))
|
||||||
}
|
}
|
||||||
if authReq.LoginHint != "" {
|
if authReq.LoginHint != "" {
|
||||||
return append(params, keyAndValueToAuthURLOpt("login_hint", authReq.LoginHint))
|
return append(params, idp.LoginHintParam(authReq.LoginHint))
|
||||||
}
|
}
|
||||||
return params
|
return params
|
||||||
}
|
}
|
||||||
|
|
||||||
func keyAndValueToAuthURLOpt(key, value string) rp.AuthURLOpt {
|
|
||||||
return func() []oauth2.AuthCodeOption {
|
|
||||||
return []oauth2.AuthCodeOption{oauth2.SetAuthURLParam(key, value)}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *Login) getUserLinks(ctx context.Context, userID, idpID string) (*query.IDPUserLinks, error) {
|
func (l *Login) getUserLinks(ctx context.Context, userID, idpID string) (*query.IDPUserLinks, error) {
|
||||||
userIDQuery, err := query.NewIDPUserLinksUserIDSearchQuery(userID)
|
userIDQuery, err := query.NewIDPUserLinksUserIDSearchQuery(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -11,7 +11,7 @@ import (
|
|||||||
// Provider is the minimal implementation for a 3rd party authentication provider
|
// Provider is the minimal implementation for a 3rd party authentication provider
|
||||||
type Provider interface {
|
type Provider interface {
|
||||||
Name() string
|
Name() string
|
||||||
BeginAuth(ctx context.Context, state string, params ...any) (Session, error)
|
BeginAuth(ctx context.Context, state string, params ...Parameter) (Session, error)
|
||||||
IsLinkingAllowed() bool
|
IsLinkingAllowed() bool
|
||||||
IsCreationAllowed() bool
|
IsCreationAllowed() bool
|
||||||
IsAutoCreation() bool
|
IsAutoCreation() bool
|
||||||
@ -34,3 +34,18 @@ type User interface {
|
|||||||
GetAvatarURL() string
|
GetAvatarURL() string
|
||||||
GetProfile() string
|
GetProfile() string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Parameter allows to pass specific parameter to the BeginAuth function
|
||||||
|
type Parameter interface {
|
||||||
|
setValue()
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserAgentID allows to pass the user agent ID of the auth request to BeginAuth
|
||||||
|
type UserAgentID string
|
||||||
|
|
||||||
|
func (p UserAgentID) setValue() {}
|
||||||
|
|
||||||
|
// LoginHintParam allows to pass a login_hint to BeginAuth
|
||||||
|
type LoginHintParam string
|
||||||
|
|
||||||
|
func (p LoginHintParam) setValue() {}
|
||||||
|
@ -91,13 +91,10 @@ func (p *Provider) Name() string {
|
|||||||
// BeginAuth implements the [idp.Provider] interface.
|
// BeginAuth implements the [idp.Provider] interface.
|
||||||
// It will create a [Session] with an AuthURL, pointing to the jwtEndpoint
|
// It will create a [Session] with an AuthURL, pointing to the jwtEndpoint
|
||||||
// with the authRequest and encrypted userAgent ids.
|
// with the authRequest and encrypted userAgent ids.
|
||||||
func (p *Provider) BeginAuth(ctx context.Context, state string, params ...any) (idp.Session, error) {
|
func (p *Provider) BeginAuth(ctx context.Context, state string, params ...idp.Parameter) (idp.Session, error) {
|
||||||
if len(params) < 1 {
|
userAgentID, err := userAgentIDFromParams(params...)
|
||||||
return nil, ErrMissingUserAgentID
|
if err != nil {
|
||||||
}
|
return nil, err
|
||||||
userAgentID, ok := params[0].(string)
|
|
||||||
if !ok {
|
|
||||||
return nil, ErrMissingUserAgentID
|
|
||||||
}
|
}
|
||||||
redirect, err := url.Parse(p.jwtEndpoint)
|
redirect, err := url.Parse(p.jwtEndpoint)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -114,6 +111,15 @@ func (p *Provider) BeginAuth(ctx context.Context, state string, params ...any) (
|
|||||||
return &Session{AuthURL: redirect.String()}, nil
|
return &Session{AuthURL: redirect.String()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func userAgentIDFromParams(params ...idp.Parameter) (string, error) {
|
||||||
|
for _, param := range params {
|
||||||
|
if id, ok := param.(idp.UserAgentID); ok {
|
||||||
|
return string(id), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", ErrMissingUserAgentID
|
||||||
|
}
|
||||||
|
|
||||||
// IsLinkingAllowed implements the [idp.Provider] interface.
|
// IsLinkingAllowed implements the [idp.Provider] interface.
|
||||||
func (p *Provider) IsLinkingAllowed() bool {
|
func (p *Provider) IsLinkingAllowed() bool {
|
||||||
return p.isLinkingAllowed
|
return p.isLinkingAllowed
|
||||||
|
@ -23,7 +23,7 @@ func TestProvider_BeginAuth(t *testing.T) {
|
|||||||
encryptionAlg func(t *testing.T) crypto.EncryptionAlgorithm
|
encryptionAlg func(t *testing.T) crypto.EncryptionAlgorithm
|
||||||
}
|
}
|
||||||
type args struct {
|
type args struct {
|
||||||
params []any
|
params []idp.Parameter
|
||||||
}
|
}
|
||||||
type want struct {
|
type want struct {
|
||||||
session idp.Session
|
session idp.Session
|
||||||
@ -55,28 +55,6 @@ func TestProvider_BeginAuth(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
name: "invalid userAgentID error",
|
|
||||||
fields: fields{
|
|
||||||
issuer: "https://jwt.com",
|
|
||||||
jwtEndpoint: "https://auth.com/jwt",
|
|
||||||
keysEndpoint: "https://jwt.com/keys",
|
|
||||||
headerName: "jwt-header",
|
|
||||||
encryptionAlg: func(t *testing.T) crypto.EncryptionAlgorithm {
|
|
||||||
return crypto.CreateMockEncryptionAlg(gomock.NewController(t))
|
|
||||||
},
|
|
||||||
},
|
|
||||||
args: args{
|
|
||||||
params: []any{
|
|
||||||
0,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
want: want{
|
|
||||||
err: func(err error) bool {
|
|
||||||
return errors.Is(err, ErrMissingUserAgentID)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
name: "successful auth",
|
name: "successful auth",
|
||||||
fields: fields{
|
fields: fields{
|
||||||
@ -89,8 +67,8 @@ func TestProvider_BeginAuth(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
args: args{
|
args: args{
|
||||||
params: []any{
|
params: []idp.Parameter{
|
||||||
"agent",
|
idp.UserAgentID("agent"),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
want: want{
|
want: want{
|
||||||
|
@ -211,7 +211,7 @@ func (p *Provider) Name() string {
|
|||||||
return p.name
|
return p.name
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Provider) BeginAuth(ctx context.Context, state string, params ...any) (idp.Session, error) {
|
func (p *Provider) BeginAuth(ctx context.Context, state string, _ ...idp.Parameter) (idp.Session, error) {
|
||||||
return &Session{
|
return &Session{
|
||||||
Provider: p,
|
Provider: p,
|
||||||
loginUrl: p.loginUrl + state,
|
loginUrl: p.loginUrl + state,
|
||||||
|
@ -87,17 +87,28 @@ func (p *Provider) Name() string {
|
|||||||
|
|
||||||
// BeginAuth implements the [idp.Provider] interface.
|
// BeginAuth implements the [idp.Provider] interface.
|
||||||
// It will create a [Session] with an OAuth2.0 authorization request as AuthURL.
|
// It will create a [Session] with an OAuth2.0 authorization request as AuthURL.
|
||||||
func (p *Provider) BeginAuth(ctx context.Context, state string, params ...any) (idp.Session, error) {
|
func (p *Provider) BeginAuth(ctx context.Context, state string, params ...idp.Parameter) (idp.Session, error) {
|
||||||
opts := []rp.AuthURLOpt{rp.WithPrompt(oidc.PromptSelectAccount)}
|
opts := make([]rp.AuthURLOpt, 0)
|
||||||
|
var loginHintSet bool
|
||||||
for _, param := range params {
|
for _, param := range params {
|
||||||
if option, ok := param.(rp.AuthURLOpt); ok {
|
if username, ok := param.(idp.LoginHintParam); ok {
|
||||||
opts = append(opts, option)
|
loginHintSet = true
|
||||||
|
opts = append(opts, loginHint(string(username)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if !loginHintSet {
|
||||||
|
opts = append(opts, rp.WithPrompt(oidc.PromptSelectAccount))
|
||||||
|
}
|
||||||
url := rp.AuthURL(state, p.RelyingParty, opts...)
|
url := rp.AuthURL(state, p.RelyingParty, opts...)
|
||||||
return &Session{AuthURL: url, Provider: p}, nil
|
return &Session{AuthURL: url, Provider: p}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func loginHint(hint string) rp.AuthURLOpt {
|
||||||
|
return func() []oauth2.AuthCodeOption {
|
||||||
|
return []oauth2.AuthCodeOption{oauth2.SetAuthURLParam("login_hint", hint)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// IsLinkingAllowed implements the [idp.Provider] interface.
|
// IsLinkingAllowed implements the [idp.Provider] interface.
|
||||||
func (p *Provider) IsLinkingAllowed() bool {
|
func (p *Provider) IsLinkingAllowed() bool {
|
||||||
return p.isLinkingAllowed
|
return p.isLinkingAllowed
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
|
|
||||||
"github.com/zitadel/oidc/v3/pkg/client/rp"
|
"github.com/zitadel/oidc/v3/pkg/client/rp"
|
||||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
|
||||||
"github.com/zitadel/zitadel/internal/idp"
|
"github.com/zitadel/zitadel/internal/idp"
|
||||||
)
|
)
|
||||||
@ -22,7 +23,7 @@ type Provider struct {
|
|||||||
isAutoUpdate bool
|
isAutoUpdate bool
|
||||||
useIDToken bool
|
useIDToken bool
|
||||||
userInfoMapper func(info *oidc.UserInfo) idp.User
|
userInfoMapper func(info *oidc.UserInfo) idp.User
|
||||||
authOptions []rp.AuthURLOpt
|
authOptions []func(bool) rp.AuthURLOpt
|
||||||
}
|
}
|
||||||
|
|
||||||
type ProviderOpts func(provider *Provider)
|
type ProviderOpts func(provider *Provider)
|
||||||
@ -70,10 +71,15 @@ func WithRelyingPartyOption(option rp.Option) ProviderOpts {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithSelectAccount adds the select_account prompt to the auth request
|
// WithSelectAccount adds the select_account prompt to the auth request (if no login_hint is set)
|
||||||
func WithSelectAccount() ProviderOpts {
|
func WithSelectAccount() ProviderOpts {
|
||||||
return func(p *Provider) {
|
return func(p *Provider) {
|
||||||
p.authOptions = append(p.authOptions, rp.WithPrompt(oidc.PromptSelectAccount))
|
p.authOptions = append(p.authOptions, func(loginHintSet bool) rp.AuthURLOpt {
|
||||||
|
if loginHintSet {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return rp.WithPrompt(oidc.PromptSelectAccount)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -81,7 +87,9 @@ func WithSelectAccount() ProviderOpts {
|
|||||||
func WithResponseMode(mode oidc.ResponseMode) ProviderOpts {
|
func WithResponseMode(mode oidc.ResponseMode) ProviderOpts {
|
||||||
return func(p *Provider) {
|
return func(p *Provider) {
|
||||||
paramOpt := rp.WithResponseModeURLParam(mode)
|
paramOpt := rp.WithResponseModeURLParam(mode)
|
||||||
p.authOptions = append(p.authOptions, rp.AuthURLOpt(paramOpt))
|
p.authOptions = append(p.authOptions, func(_ bool) rp.AuthURLOpt {
|
||||||
|
return rp.AuthURLOpt(paramOpt)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -128,17 +136,30 @@ func (p *Provider) Name() string {
|
|||||||
|
|
||||||
// BeginAuth implements the [idp.Provider] interface.
|
// BeginAuth implements the [idp.Provider] interface.
|
||||||
// It will create a [Session] with an OIDC authorization request as AuthURL.
|
// It will create a [Session] with an OIDC authorization request as AuthURL.
|
||||||
func (p *Provider) BeginAuth(ctx context.Context, state string, params ...any) (idp.Session, error) {
|
func (p *Provider) BeginAuth(ctx context.Context, state string, params ...idp.Parameter) (idp.Session, error) {
|
||||||
opts := p.authOptions
|
opts := make([]rp.AuthURLOpt, 0)
|
||||||
|
var loginHintSet bool
|
||||||
for _, param := range params {
|
for _, param := range params {
|
||||||
if option, ok := param.(rp.AuthURLOpt); ok {
|
if username, ok := param.(idp.LoginHintParam); ok {
|
||||||
opts = append(opts, option)
|
loginHintSet = true
|
||||||
|
opts = append(opts, loginHint(string(username)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, option := range p.authOptions {
|
||||||
|
if opt := option(loginHintSet); opt != nil {
|
||||||
|
opts = append(opts, opt)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
url := rp.AuthURL(state, p.RelyingParty, opts...)
|
url := rp.AuthURL(state, p.RelyingParty, opts...)
|
||||||
return &Session{AuthURL: url, Provider: p}, nil
|
return &Session{AuthURL: url, Provider: p}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func loginHint(hint string) rp.AuthURLOpt {
|
||||||
|
return func() []oauth2.AuthCodeOption {
|
||||||
|
return []oauth2.AuthCodeOption{oauth2.SetAuthURLParam("login_hint", hint)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// IsLinkingAllowed implements the [idp.Provider] interface.
|
// IsLinkingAllowed implements the [idp.Provider] interface.
|
||||||
func (p *Provider) IsLinkingAllowed() bool {
|
func (p *Provider) IsLinkingAllowed() bool {
|
||||||
return p.isLinkingAllowed
|
return p.isLinkingAllowed
|
||||||
|
@ -162,7 +162,7 @@ func (p *Provider) GetSP() (*samlsp.Middleware, error) {
|
|||||||
return sp, nil
|
return sp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Provider) BeginAuth(ctx context.Context, state string, params ...any) (idp.Session, error) {
|
func (p *Provider) BeginAuth(ctx context.Context, state string, _ ...idp.Parameter) (idp.Session, error) {
|
||||||
m, err := p.GetSP()
|
m, err := p.GetSP()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
Loading…
Reference in New Issue
Block a user