fix: provide tokens in azuread idp session (#6334)

This commit is contained in:
Livio Spring 2023-08-08 11:28:47 +02:00 committed by GitHub
parent 605e683e29
commit 8dc1fd06a1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 31 additions and 0 deletions

View File

@ -967,6 +967,8 @@ func tokens(session idp.Session) *oidc.Tokens[*oidc.IDTokenClaims] {
return s.Tokens
case *oauth.Session:
return s.Tokens
case *azuread.Session:
return s.Tokens
}
return nil
}

View File

@ -14,6 +14,7 @@ import (
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/idp"
"github.com/zitadel/zitadel/internal/idp/providers/azuread"
"github.com/zitadel/zitadel/internal/idp/providers/jwt"
"github.com/zitadel/zitadel/internal/idp/providers/oauth"
openid "github.com/zitadel/zitadel/internal/idp/providers/oidc"
@ -165,6 +166,8 @@ func tokensForSucceededIDPIntent(session idp.Session, encryptionAlg crypto.Encry
tokens = s.Tokens
case *jwt.Session:
tokens = s.Tokens
case *azuread.Session:
tokens = s.Tokens
default:
return nil, "", nil
}

View File

@ -19,6 +19,7 @@ import (
"github.com/zitadel/zitadel/internal/id"
"github.com/zitadel/zitadel/internal/id/mock"
"github.com/zitadel/zitadel/internal/idp"
"github.com/zitadel/zitadel/internal/idp/providers/azuread"
"github.com/zitadel/zitadel/internal/idp/providers/jwt"
"github.com/zitadel/zitadel/internal/idp/providers/ldap"
"github.com/zitadel/zitadel/internal/idp/providers/oauth"
@ -745,6 +746,31 @@ func Test_tokensForSucceededIDPIntent(t *testing.T) {
err: nil,
},
},
{
"azure tokens",
args{
&azuread.Session{
Session: &oauth.Session{
Tokens: &oidc.Tokens[*oidc.IDTokenClaims]{
Token: &oauth2.Token{
AccessToken: "accessToken",
},
},
},
},
crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
res{
accessToken: &crypto.CryptoValue{
CryptoType: crypto.TypeEncryption,
Algorithm: "enc",
KeyID: "id",
Crypted: []byte("accessToken"),
},
idToken: "",
err: nil,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {