From 7c592ce6387f40a529124782fff6eb7b6ded7561 Mon Sep 17 00:00:00 2001 From: Livio Spring Date: Wed, 10 Jan 2024 16:02:17 +0100 Subject: [PATCH] fix(idp): provide id_token for tenant id based azure ad (#7188) * fix(idp): provide id_token for tenant based azure ad * comments * remove unintentional changes --- internal/api/idp/idp.go | 2 +- .../api/ui/login/external_provider_handler.go | 4 +- internal/command/idp_intent.go | 2 +- internal/command/idp_intent_test.go | 5 +- internal/idp/providers/azuread/azuread.go | 12 ++++ internal/idp/providers/azuread/session.go | 64 ++++++++++++++++++- .../idp/providers/azuread/session_test.go | 25 +++++--- 7 files changed, 96 insertions(+), 18 deletions(-) diff --git a/internal/api/idp/idp.go b/internal/api/idp/idp.go index acae4f36b0..6a2cb6e32e 100644 --- a/internal/api/idp/idp.go +++ b/internal/api/idp/idp.go @@ -393,7 +393,7 @@ func (h *Handler) fetchIDPUserFromCode(ctx context.Context, identityProvider idp case *openid.Provider: session = &openid.Session{Provider: provider, Code: code} case *azuread.Provider: - session = &azuread.Session{Session: &oauth.Session{Provider: provider.Provider, Code: code}} + session = &azuread.Session{Provider: provider, Code: code} case *github.Provider: session = &oauth.Session{Provider: provider.Provider, Code: code} case *gitlab.Provider: diff --git a/internal/api/ui/login/external_provider_handler.go b/internal/api/ui/login/external_provider_handler.go index 532bd69e20..14fc930751 100644 --- a/internal/api/ui/login/external_provider_handler.go +++ b/internal/api/ui/login/external_provider_handler.go @@ -270,7 +270,7 @@ func (l *Login) handleExternalLoginCallback(w http.ResponseWriter, r *http.Reque l.externalAuthFailed(w, r, authReq, nil, nil, err) return } - session = &azuread.Session{Session: &oauth.Session{Provider: provider.(*azuread.Provider).Provider, Code: data.Code}} + session = &azuread.Session{Provider: provider.(*azuread.Provider), Code: data.Code} case domain.IDPTypeGitHub: provider, err = l.githubProvider(r.Context(), identityProvider) if err != nil { @@ -1132,7 +1132,7 @@ func tokens(session idp.Session) *oidc.Tokens[*oidc.IDTokenClaims] { case *oauth.Session: return s.Tokens case *azuread.Session: - return s.Tokens + return s.Tokens() case *apple.Session: return s.Tokens } diff --git a/internal/command/idp_intent.go b/internal/command/idp_intent.go index 1afc823fb2..38fc9c91b4 100644 --- a/internal/command/idp_intent.go +++ b/internal/command/idp_intent.go @@ -285,7 +285,7 @@ func tokensForSucceededIDPIntent(session idp.Session, encryptionAlg crypto.Encry case *jwt.Session: tokens = s.Tokens case *azuread.Session: - tokens = s.Tokens + tokens = s.Tokens() case *apple.Session: tokens = s.Tokens default: diff --git a/internal/command/idp_intent_test.go b/internal/command/idp_intent_test.go index a297548f6b..3043c7848a 100644 --- a/internal/command/idp_intent_test.go +++ b/internal/command/idp_intent_test.go @@ -1166,11 +1166,12 @@ func Test_tokensForSucceededIDPIntent(t *testing.T) { "azure tokens", args{ &azuread.Session{ - Session: &oauth.Session{ + OAuthSession: &oauth.Session{ Tokens: &oidc.Tokens[*oidc.IDTokenClaims]{ Token: &oauth2.Token{ AccessToken: "accessToken", }, + IDToken: "idToken", }, }, }, @@ -1183,7 +1184,7 @@ func Test_tokensForSucceededIDPIntent(t *testing.T) { KeyID: "id", Crypted: []byte("accessToken"), }, - idToken: "", + idToken: "idToken", err: nil, }, }, diff --git a/internal/idp/providers/azuread/azuread.go b/internal/idp/providers/azuread/azuread.go index 46445a3977..65f38ede5b 100644 --- a/internal/idp/providers/azuread/azuread.go +++ b/internal/idp/providers/azuread/azuread.go @@ -13,8 +13,10 @@ import ( ) const ( + issuerTemplate string = "https://login.microsoftonline.com/%s/v2.0" authURLTemplate string = "https://login.microsoftonline.com/%s/oauth2/v2.0/authorize" tokenURLTemplate string = "https://login.microsoftonline.com/%s/oauth2/v2.0/token" + keysURLTemplate string = "https://login.microsoftonline.com/%s/discovery/v2.0/keys" userURL string = "https://graph.microsoft.com/v1.0/me" userinfoEndpoint string = "https://graph.microsoft.com/oidc/userinfo" @@ -50,6 +52,16 @@ type Provider struct { options []oauth.ProviderOpts } +// issuer returns the OIDC issuer based on the [TenantType] +func (p *Provider) issuer() string { + return fmt.Sprintf(issuerTemplate, p.tenant) +} + +// keysEndpoint returns the OIDC jwks_url based on the [TenantType] +func (p *Provider) keysEndpoint() string { + return fmt.Sprintf(keysURLTemplate, p.tenant) +} + type ProviderOptions func(*Provider) // WithTenant allows to set a [TenantType] (can also be a Tenant ID) diff --git a/internal/idp/providers/azuread/session.go b/internal/idp/providers/azuread/session.go index 698cdad198..a9d8df2e8c 100644 --- a/internal/idp/providers/azuread/session.go +++ b/internal/idp/providers/azuread/session.go @@ -1,17 +1,28 @@ package azuread import ( + "context" "net/http" + "github.com/zitadel/oidc/v3/pkg/client/rp" httphelper "github.com/zitadel/oidc/v3/pkg/http" "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/zitadel/internal/idp" "github.com/zitadel/zitadel/internal/idp/providers/oauth" ) -// Session extends the [oauth.Session] to extend it with the [idp.SessionSupportsMigration] functionality +// Session extends the [oauth.Session] to be able to handle the id_token and to implement the [idp.SessionSupportsMigration] functionality type Session struct { - *oauth.Session + *Provider + Code string + + OAuthSession *oauth.Session +} + +// GetAuth implements the [idp.Provider] interface by calling the wrapped [oauth.Session]. +func (s *Session) GetAuth(ctx context.Context) (content string, redirect bool) { + return s.oauth().GetAuth(ctx) } // RetrievePreviousID implements the [idp.SessionSupportsMigration] interface by returning the `sub` from the userinfo endpoint @@ -20,10 +31,57 @@ func (s *Session) RetrievePreviousID() (string, error) { if err != nil { return "", err } - req.Header.Set("authorization", s.Tokens.TokenType+" "+s.Tokens.AccessToken) + req.Header.Set("authorization", s.oauth().Tokens.TokenType+" "+s.oauth().Tokens.AccessToken) userinfo := new(oidc.UserInfo) if err := httphelper.HttpRequest(s.Provider.HttpClient(), req, &userinfo); err != nil { return "", err } return userinfo.Subject, nil } + +// FetchUser implements the [idp.Session] interface. +// It will execute an OAuth 2.0 code exchange if needed to retrieve the access token, +// call the specified userEndpoint and map the received information into an [idp.User]. +// In case of a specific TenantID as [TenantType] it will additionally extract the id_token and validate it. +func (s *Session) FetchUser(ctx context.Context) (user idp.User, err error) { + user, err = s.oauth().FetchUser(ctx) + if err != nil { + return nil, err + } + // since azure will sign the id_token always with the issuer of the application it might differ from + // the issuer the auth and token were based on, e.g. when allowing all account types to login, + // then the auth endpoint must be `https://login.microsoftonline.com/common/oauth2/v2.0/authorize` + // even though the issuer would be like `https://login.microsoftonline.com/d8cdd43f-fd94-4576-8deb-f3bfea72dc2e/v2.0` + if s.Provider.tenant == CommonTenant || + s.Provider.tenant == OrganizationsTenant || + s.Provider.tenant == ConsumersTenant { + return user, nil + } + idToken, ok := s.oauth().Tokens.Extra("id_token").(string) + if !ok { + return user, nil + } + idTokenVerifier := rp.NewIDTokenVerifier(s.Provider.issuer(), s.Provider.OAuthConfig().ClientID, rp.NewRemoteKeySet(s.Provider.HttpClient(), s.Provider.keysEndpoint())) + s.oauth().Tokens.IDTokenClaims, err = rp.VerifyTokens[*oidc.IDTokenClaims](ctx, s.oauth().Tokens.AccessToken, idToken, idTokenVerifier) + if err != nil { + return nil, err + } + s.oauth().Tokens.IDToken = idToken + return user, nil +} + +// Tokens returns the [oidc.Tokens] of the underlying [oauth.Session]. +func (s *Session) Tokens() *oidc.Tokens[*oidc.IDTokenClaims] { + return s.oauth().Tokens +} + +func (s *Session) oauth() *oauth.Session { + if s.OAuthSession != nil { + return s.OAuthSession + } + s.OAuthSession = &oauth.Session{ + Code: s.Code, + Provider: s.Provider.Provider, + } + return s.OAuthSession +} diff --git a/internal/idp/providers/azuread/session_test.go b/internal/idp/providers/azuread/session_test.go index f68c4cc7d7..8215184ae3 100644 --- a/internal/idp/providers/azuread/session_test.go +++ b/internal/idp/providers/azuread/session_test.go @@ -247,12 +247,17 @@ func TestSession_FetchUser(t *testing.T) { provider, err := New(tt.fields.name, tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.scopes, tt.fields.options...) require.NoError(t, err) - session := &Session{Session: &oauth.Session{ - AuthURL: tt.fields.authURL, + session := &Session{ + Provider: provider, Code: tt.fields.code, - Tokens: tt.fields.tokens, - Provider: provider.Provider, - }} + + OAuthSession: &oauth.Session{ + AuthURL: tt.fields.authURL, + Tokens: tt.fields.tokens, + Provider: provider.Provider, + Code: tt.fields.code, + }, + } user, err := session.FetchUser(context.Background()) if tt.want.err != nil && !tt.want.err(err) { @@ -392,10 +397,12 @@ func TestSession_RetrievePreviousID(t *testing.T) { provider, err := New(tt.fields.name, tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.scopes) require.NoError(t, err) - session := &Session{Session: &oauth.Session{ - Tokens: tt.fields.tokens, - Provider: provider.Provider, - }} + session := &Session{ + Provider: provider, + OAuthSession: &oauth.Session{ + Tokens: tt.fields.tokens, + Provider: provider.Provider, + }} id, err := session.RetrievePreviousID() if tt.res.err {