mirror of
https://github.com/zitadel/zitadel.git
synced 2025-01-06 13:37:40 +00:00
fix(idp): provide id_token for tenant id based azure ad (#7188)
* fix(idp): provide id_token for tenant based azure ad * comments * remove unintentional changes (cherry picked from commit 7c592ce6387f40a529124782fff6eb7b6ded7561)
This commit is contained in:
parent
a31191d9e2
commit
b63534c325
@ -393,7 +393,7 @@ func (h *Handler) fetchIDPUserFromCode(ctx context.Context, identityProvider idp
|
|||||||
case *openid.Provider:
|
case *openid.Provider:
|
||||||
session = &openid.Session{Provider: provider, Code: code}
|
session = &openid.Session{Provider: provider, Code: code}
|
||||||
case *azuread.Provider:
|
case *azuread.Provider:
|
||||||
session = &azuread.Session{Session: &oauth.Session{Provider: provider.Provider, Code: code}}
|
session = &azuread.Session{Provider: provider, Code: code}
|
||||||
case *github.Provider:
|
case *github.Provider:
|
||||||
session = &oauth.Session{Provider: provider.Provider, Code: code}
|
session = &oauth.Session{Provider: provider.Provider, Code: code}
|
||||||
case *gitlab.Provider:
|
case *gitlab.Provider:
|
||||||
|
@ -270,7 +270,7 @@ func (l *Login) handleExternalLoginCallback(w http.ResponseWriter, r *http.Reque
|
|||||||
l.externalAuthFailed(w, r, authReq, nil, nil, err)
|
l.externalAuthFailed(w, r, authReq, nil, nil, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
session = &azuread.Session{Session: &oauth.Session{Provider: provider.(*azuread.Provider).Provider, Code: data.Code}}
|
session = &azuread.Session{Provider: provider.(*azuread.Provider), Code: data.Code}
|
||||||
case domain.IDPTypeGitHub:
|
case domain.IDPTypeGitHub:
|
||||||
provider, err = l.githubProvider(r.Context(), identityProvider)
|
provider, err = l.githubProvider(r.Context(), identityProvider)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -1132,7 +1132,7 @@ func tokens(session idp.Session) *oidc.Tokens[*oidc.IDTokenClaims] {
|
|||||||
case *oauth.Session:
|
case *oauth.Session:
|
||||||
return s.Tokens
|
return s.Tokens
|
||||||
case *azuread.Session:
|
case *azuread.Session:
|
||||||
return s.Tokens
|
return s.Tokens()
|
||||||
case *apple.Session:
|
case *apple.Session:
|
||||||
return s.Tokens
|
return s.Tokens
|
||||||
}
|
}
|
||||||
|
@ -285,7 +285,7 @@ func tokensForSucceededIDPIntent(session idp.Session, encryptionAlg crypto.Encry
|
|||||||
case *jwt.Session:
|
case *jwt.Session:
|
||||||
tokens = s.Tokens
|
tokens = s.Tokens
|
||||||
case *azuread.Session:
|
case *azuread.Session:
|
||||||
tokens = s.Tokens
|
tokens = s.Tokens()
|
||||||
case *apple.Session:
|
case *apple.Session:
|
||||||
tokens = s.Tokens
|
tokens = s.Tokens
|
||||||
default:
|
default:
|
||||||
|
@ -1166,11 +1166,12 @@ func Test_tokensForSucceededIDPIntent(t *testing.T) {
|
|||||||
"azure tokens",
|
"azure tokens",
|
||||||
args{
|
args{
|
||||||
&azuread.Session{
|
&azuread.Session{
|
||||||
Session: &oauth.Session{
|
OAuthSession: &oauth.Session{
|
||||||
Tokens: &oidc.Tokens[*oidc.IDTokenClaims]{
|
Tokens: &oidc.Tokens[*oidc.IDTokenClaims]{
|
||||||
Token: &oauth2.Token{
|
Token: &oauth2.Token{
|
||||||
AccessToken: "accessToken",
|
AccessToken: "accessToken",
|
||||||
},
|
},
|
||||||
|
IDToken: "idToken",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -1183,7 +1184,7 @@ func Test_tokensForSucceededIDPIntent(t *testing.T) {
|
|||||||
KeyID: "id",
|
KeyID: "id",
|
||||||
Crypted: []byte("accessToken"),
|
Crypted: []byte("accessToken"),
|
||||||
},
|
},
|
||||||
idToken: "",
|
idToken: "idToken",
|
||||||
err: nil,
|
err: nil,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -13,8 +13,10 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
issuerTemplate string = "https://login.microsoftonline.com/%s/v2.0"
|
||||||
authURLTemplate string = "https://login.microsoftonline.com/%s/oauth2/v2.0/authorize"
|
authURLTemplate string = "https://login.microsoftonline.com/%s/oauth2/v2.0/authorize"
|
||||||
tokenURLTemplate string = "https://login.microsoftonline.com/%s/oauth2/v2.0/token"
|
tokenURLTemplate string = "https://login.microsoftonline.com/%s/oauth2/v2.0/token"
|
||||||
|
keysURLTemplate string = "https://login.microsoftonline.com/%s/discovery/v2.0/keys"
|
||||||
userURL string = "https://graph.microsoft.com/v1.0/me"
|
userURL string = "https://graph.microsoft.com/v1.0/me"
|
||||||
userinfoEndpoint string = "https://graph.microsoft.com/oidc/userinfo"
|
userinfoEndpoint string = "https://graph.microsoft.com/oidc/userinfo"
|
||||||
|
|
||||||
@ -50,6 +52,16 @@ type Provider struct {
|
|||||||
options []oauth.ProviderOpts
|
options []oauth.ProviderOpts
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// issuer returns the OIDC issuer based on the [TenantType]
|
||||||
|
func (p *Provider) issuer() string {
|
||||||
|
return fmt.Sprintf(issuerTemplate, p.tenant)
|
||||||
|
}
|
||||||
|
|
||||||
|
// keysEndpoint returns the OIDC jwks_url based on the [TenantType]
|
||||||
|
func (p *Provider) keysEndpoint() string {
|
||||||
|
return fmt.Sprintf(keysURLTemplate, p.tenant)
|
||||||
|
}
|
||||||
|
|
||||||
type ProviderOptions func(*Provider)
|
type ProviderOptions func(*Provider)
|
||||||
|
|
||||||
// WithTenant allows to set a [TenantType] (can also be a Tenant ID)
|
// WithTenant allows to set a [TenantType] (can also be a Tenant ID)
|
||||||
|
@ -1,17 +1,28 @@
|
|||||||
package azuread
|
package azuread
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/zitadel/oidc/v3/pkg/client/rp"
|
||||||
httphelper "github.com/zitadel/oidc/v3/pkg/http"
|
httphelper "github.com/zitadel/oidc/v3/pkg/http"
|
||||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||||
|
|
||||||
|
"github.com/zitadel/zitadel/internal/idp"
|
||||||
"github.com/zitadel/zitadel/internal/idp/providers/oauth"
|
"github.com/zitadel/zitadel/internal/idp/providers/oauth"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Session extends the [oauth.Session] to extend it with the [idp.SessionSupportsMigration] functionality
|
// Session extends the [oauth.Session] to be able to handle the id_token and to implement the [idp.SessionSupportsMigration] functionality
|
||||||
type Session struct {
|
type Session struct {
|
||||||
*oauth.Session
|
*Provider
|
||||||
|
Code string
|
||||||
|
|
||||||
|
OAuthSession *oauth.Session
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAuth implements the [idp.Provider] interface by calling the wrapped [oauth.Session].
|
||||||
|
func (s *Session) GetAuth(ctx context.Context) (content string, redirect bool) {
|
||||||
|
return s.oauth().GetAuth(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RetrievePreviousID implements the [idp.SessionSupportsMigration] interface by returning the `sub` from the userinfo endpoint
|
// RetrievePreviousID implements the [idp.SessionSupportsMigration] interface by returning the `sub` from the userinfo endpoint
|
||||||
@ -20,10 +31,57 @@ func (s *Session) RetrievePreviousID() (string, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
req.Header.Set("authorization", s.Tokens.TokenType+" "+s.Tokens.AccessToken)
|
req.Header.Set("authorization", s.oauth().Tokens.TokenType+" "+s.oauth().Tokens.AccessToken)
|
||||||
userinfo := new(oidc.UserInfo)
|
userinfo := new(oidc.UserInfo)
|
||||||
if err := httphelper.HttpRequest(s.Provider.HttpClient(), req, &userinfo); err != nil {
|
if err := httphelper.HttpRequest(s.Provider.HttpClient(), req, &userinfo); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
return userinfo.Subject, nil
|
return userinfo.Subject, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FetchUser implements the [idp.Session] interface.
|
||||||
|
// It will execute an OAuth 2.0 code exchange if needed to retrieve the access token,
|
||||||
|
// call the specified userEndpoint and map the received information into an [idp.User].
|
||||||
|
// In case of a specific TenantID as [TenantType] it will additionally extract the id_token and validate it.
|
||||||
|
func (s *Session) FetchUser(ctx context.Context) (user idp.User, err error) {
|
||||||
|
user, err = s.oauth().FetchUser(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// since azure will sign the id_token always with the issuer of the application it might differ from
|
||||||
|
// the issuer the auth and token were based on, e.g. when allowing all account types to login,
|
||||||
|
// then the auth endpoint must be `https://login.microsoftonline.com/common/oauth2/v2.0/authorize`
|
||||||
|
// even though the issuer would be like `https://login.microsoftonline.com/d8cdd43f-fd94-4576-8deb-f3bfea72dc2e/v2.0`
|
||||||
|
if s.Provider.tenant == CommonTenant ||
|
||||||
|
s.Provider.tenant == OrganizationsTenant ||
|
||||||
|
s.Provider.tenant == ConsumersTenant {
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
idToken, ok := s.oauth().Tokens.Extra("id_token").(string)
|
||||||
|
if !ok {
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
idTokenVerifier := rp.NewIDTokenVerifier(s.Provider.issuer(), s.Provider.OAuthConfig().ClientID, rp.NewRemoteKeySet(s.Provider.HttpClient(), s.Provider.keysEndpoint()))
|
||||||
|
s.oauth().Tokens.IDTokenClaims, err = rp.VerifyTokens[*oidc.IDTokenClaims](ctx, s.oauth().Tokens.AccessToken, idToken, idTokenVerifier)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
s.oauth().Tokens.IDToken = idToken
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tokens returns the [oidc.Tokens] of the underlying [oauth.Session].
|
||||||
|
func (s *Session) Tokens() *oidc.Tokens[*oidc.IDTokenClaims] {
|
||||||
|
return s.oauth().Tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) oauth() *oauth.Session {
|
||||||
|
if s.OAuthSession != nil {
|
||||||
|
return s.OAuthSession
|
||||||
|
}
|
||||||
|
s.OAuthSession = &oauth.Session{
|
||||||
|
Code: s.Code,
|
||||||
|
Provider: s.Provider.Provider,
|
||||||
|
}
|
||||||
|
return s.OAuthSession
|
||||||
|
}
|
||||||
|
@ -247,12 +247,17 @@ func TestSession_FetchUser(t *testing.T) {
|
|||||||
provider, err := New(tt.fields.name, tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.scopes, tt.fields.options...)
|
provider, err := New(tt.fields.name, tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.scopes, tt.fields.options...)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
session := &Session{Session: &oauth.Session{
|
session := &Session{
|
||||||
AuthURL: tt.fields.authURL,
|
Provider: provider,
|
||||||
Code: tt.fields.code,
|
Code: tt.fields.code,
|
||||||
Tokens: tt.fields.tokens,
|
|
||||||
Provider: provider.Provider,
|
OAuthSession: &oauth.Session{
|
||||||
}}
|
AuthURL: tt.fields.authURL,
|
||||||
|
Tokens: tt.fields.tokens,
|
||||||
|
Provider: provider.Provider,
|
||||||
|
Code: tt.fields.code,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
user, err := session.FetchUser(context.Background())
|
user, err := session.FetchUser(context.Background())
|
||||||
if tt.want.err != nil && !tt.want.err(err) {
|
if tt.want.err != nil && !tt.want.err(err) {
|
||||||
@ -392,10 +397,12 @@ func TestSession_RetrievePreviousID(t *testing.T) {
|
|||||||
|
|
||||||
provider, err := New(tt.fields.name, tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.scopes)
|
provider, err := New(tt.fields.name, tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.scopes)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
session := &Session{Session: &oauth.Session{
|
session := &Session{
|
||||||
Tokens: tt.fields.tokens,
|
Provider: provider,
|
||||||
Provider: provider.Provider,
|
OAuthSession: &oauth.Session{
|
||||||
}}
|
Tokens: tt.fields.tokens,
|
||||||
|
Provider: provider.Provider,
|
||||||
|
}}
|
||||||
|
|
||||||
id, err := session.RetrievePreviousID()
|
id, err := session.RetrievePreviousID()
|
||||||
if tt.res.err {
|
if tt.res.err {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user