From f99cf50f69e7f19a779a32ea7fcc3002616652b6 Mon Sep 17 00:00:00 2001 From: Stefan Benz <46600784+stebenz@users.noreply.github.com> Date: Tue, 14 Mar 2023 16:42:29 +0100 Subject: [PATCH] fix: add authURLParams to urls for external idps (#5404) add authURL parameters to urls for external IDPs, depended on the contents of the authRequest --------- Co-authored-by: Livio Spring --- .../api/ui/login/external_provider_handler.go | 56 ++++++++++++++++++- internal/idp/providers/jwt/jwt.go | 2 +- internal/idp/providers/oauth/oauth2.go | 10 +++- internal/idp/providers/oidc/oidc.go | 10 +++- 4 files changed, 71 insertions(+), 7 deletions(-) diff --git a/internal/api/ui/login/external_provider_handler.go b/internal/api/ui/login/external_provider_handler.go index 5ea64e0bae..f887b57ee6 100644 --- a/internal/api/ui/login/external_provider_handler.go +++ b/internal/api/ui/login/external_provider_handler.go @@ -136,6 +136,7 @@ func (l *Login) handleIDP(w http.ResponseWriter, r *http.Request, authReq *domai return } var provider idp.Provider + switch identityProvider.Type { case domain.IDPTypeOAuth: provider, err = l.oauthProvider(r.Context(), identityProvider) @@ -165,7 +166,8 @@ func (l *Login) handleIDP(w http.ResponseWriter, r *http.Request, authReq *domai l.renderLogin(w, r, authReq, err) return } - session, err := provider.BeginAuth(r.Context(), authReq.ID, authReq.AgentID) + params := l.sessionParamsFromAuthRequest(r.Context(), authReq, identityProvider.ID) + session, err := provider.BeginAuth(r.Context(), authReq.ID, params...) if err != nil { l.renderLogin(w, r, authReq, err) return @@ -801,7 +803,7 @@ func mapExternalUserToLoginUser(externalUser *domain.ExternalUser, mustBeDomain externalIDP := &domain.UserIDPLink{ IDPConfigID: externalUser.IDPConfigID, ExternalUserID: externalUser.ExternalUserID, - DisplayName: externalUser.DisplayName, + DisplayName: externalUser.PreferredUsername, } return human, externalIDP, externalUser.Metadatas } @@ -824,3 +826,53 @@ func mapExternalNotFoundOptionFormDataToLoginUser(formData *externalNotFoundOpti PreferredLanguage: language.Make(formData.Language), } } + +func (l *Login) sessionParamsFromAuthRequest(ctx context.Context, authReq *domain.AuthRequest, identityProviderID string) []any { + params := []any{authReq.AgentID} + + if authReq.UserID != "" && identityProviderID != "" { + links, err := l.getUserLinks(ctx, authReq.UserID, identityProviderID) + if err != nil { + logging.WithFields("authReqID", authReq.ID, "userID", authReq.UserID, "providerID", identityProviderID).WithError(err).Warn("failed to get user links for") + return params + } + if len(links.Links) == 1 { + return append(params, keyAndValueToAuthURLOpt("login_hint", links.Links[0].ProvidedUsername)) + } + } + if authReq.UserName != "" { + return append(params, keyAndValueToAuthURLOpt("login_hint", authReq.UserName)) + } + if authReq.LoginName != "" { + return append(params, keyAndValueToAuthURLOpt("login_hint", authReq.LoginName)) + } + if authReq.LoginHint != "" { + return append(params, keyAndValueToAuthURLOpt("login_hint", authReq.LoginHint)) + } + return params +} + +func keyAndValueToAuthURLOpt(key, value string) rp.AuthURLOpt { + return func() []oauth2.AuthCodeOption { + return []oauth2.AuthCodeOption{oauth2.SetAuthURLParam(key, value)} + } +} + +func (l *Login) getUserLinks(ctx context.Context, userID, idpID string) (*query.IDPUserLinks, error) { + userIDQuery, err := query.NewIDPUserLinksUserIDSearchQuery(userID) + if err != nil { + return nil, err + } + idpIDQuery, err := query.NewIDPUserLinkIDPIDSearchQuery(idpID) + if err != nil { + return nil, err + } + return l.query.IDPUserLinks(ctx, + &query.IDPUserLinksSearchQuery{ + Queries: []query.SearchQuery{ + userIDQuery, + idpIDQuery, + }, + }, false, + ) +} diff --git a/internal/idp/providers/jwt/jwt.go b/internal/idp/providers/jwt/jwt.go index bd2effac8c..40e4a8e153 100644 --- a/internal/idp/providers/jwt/jwt.go +++ b/internal/idp/providers/jwt/jwt.go @@ -92,7 +92,7 @@ func (p *Provider) Name() string { // It will create a [Session] with an AuthURL, pointing to the jwtEndpoint // with the authRequest and encrypted userAgent ids. func (p *Provider) BeginAuth(ctx context.Context, state string, params ...any) (idp.Session, error) { - if len(params) != 1 { + if len(params) < 1 { return nil, ErrMissingUserAgentID } userAgentID, ok := params[0].(string) diff --git a/internal/idp/providers/oauth/oauth2.go b/internal/idp/providers/oauth/oauth2.go index 774a4dcec5..a31e9d4c26 100644 --- a/internal/idp/providers/oauth/oauth2.go +++ b/internal/idp/providers/oauth/oauth2.go @@ -87,8 +87,14 @@ func (p *Provider) Name() string { // BeginAuth implements the [idp.Provider] interface. // It will create a [Session] with an OAuth2.0 authorization request as AuthURL. -func (p *Provider) BeginAuth(ctx context.Context, state string, _ ...any) (idp.Session, error) { - url := rp.AuthURL(state, p.RelyingParty, rp.WithPrompt(oidc.PromptSelectAccount)) +func (p *Provider) BeginAuth(ctx context.Context, state string, params ...any) (idp.Session, error) { + opts := []rp.AuthURLOpt{rp.WithPrompt(oidc.PromptSelectAccount)} + for _, param := range params { + if option, ok := param.(rp.AuthURLOpt); ok { + opts = append(opts, option) + } + } + url := rp.AuthURL(state, p.RelyingParty, opts...) return &Session{AuthURL: url, Provider: p}, nil } diff --git a/internal/idp/providers/oidc/oidc.go b/internal/idp/providers/oidc/oidc.go index c8f86dd7a3..2a0c305bdf 100644 --- a/internal/idp/providers/oidc/oidc.go +++ b/internal/idp/providers/oidc/oidc.go @@ -112,8 +112,14 @@ func (p *Provider) Name() string { // BeginAuth implements the [idp.Provider] interface. // It will create a [Session] with an OIDC authorization request as AuthURL. -func (p *Provider) BeginAuth(ctx context.Context, state string, _ ...any) (idp.Session, error) { - url := rp.AuthURL(state, p.RelyingParty, p.authOptions...) +func (p *Provider) BeginAuth(ctx context.Context, state string, params ...any) (idp.Session, error) { + opts := p.authOptions + for _, param := range params { + if option, ok := param.(rp.AuthURLOpt); ok { + opts = append(opts, option) + } + } + url := rp.AuthURL(state, p.RelyingParty, opts...) return &Session{AuthURL: url, Provider: p}, nil }