fix(oidc): respect role assertion and idTokenInfo flags and trigger preAccessToken trigger (#8046)

# Which Problems Are Solved

After deployment of 2.53.x, customers noted that the roles claims where
always present in the tokens even if the corresponding option on the
client (accessTokenRoleAssertion, idTokenRoleAsseriton) was disabled.
Only the project flag (assertRolesOnAuthentication) would be considered.

Further it was noted, that the action on the preAccessTokenCreation
trigger would not be executed.

Additionally, while testing those issues we found out, that the user
information (name, givenname, family name, ...) where always present in
the id_token even if the option (idTokenUserInfo) was not enabled.

# How the Problems Are Solved

- The `getUserinfoOnce` which was used for access and id_tokens is
refactored to `getUserInfo` and no longer only queries the info once
from the database, but still provides a mechanism to be reused for
access and id_token where the corresponding `roleAssertion` and action
`triggerType` can be passed.
- `userInfo` on the other hand now directly makes sure the information
is only queried once from the database. Role claims are asserted every
time and action triggers are executed on every call.
- `userInfo` now also checks if the profile information need to be
returned.

# Additional Changes

None.

# Additional Context

- relates to #7822 
- reported by customers
This commit is contained in:
Livio Spring 2024-05-31 12:10:18 +02:00 committed by GitHub
parent bc885632fb
commit f065b42a97
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 263 additions and 195 deletions

View File

@ -471,7 +471,7 @@ func (s *Server) CreateTokenCallbackURL(ctx context.Context, req op.AuthRequest)
if err != nil {
return "", err
}
resp, err := s.accessTokenResponseFromSession(ctx, client, session, state, client.client.ProjectID, client.client.ProjectRoleAssertion)
resp, err := s.accessTokenResponseFromSession(ctx, client, session, state, client.client.ProjectID, client.client.ProjectRoleAssertion, client.client.AccessTokenRoleAssertion, client.client.IDTokenRoleAssertion, client.client.IDTokenUserinfoAssertion)
if err != nil {
return "", err
}
@ -563,7 +563,7 @@ func (s *Server) authResponseToken(authReq *AuthRequest, authorizer op.Authorize
op.AuthRequestError(w, r, authReq, err, authorizer)
return err
}
resp, err := s.accessTokenResponseFromSession(ctx, client, session, authReq.GetState(), client.client.ProjectID, client.client.ProjectRoleAssertion)
resp, err := s.accessTokenResponseFromSession(ctx, client, session, authReq.GetState(), client.client.ProjectID, client.client.ProjectRoleAssertion, client.client.AccessTokenRoleAssertion, client.client.IDTokenRoleAssertion, client.client.IDTokenUserinfoAssertion)
if err != nil {
op.AuthRequestError(w, r, authReq, err, authorizer)
return err

View File

@ -54,7 +54,7 @@ func TestOPStorage_CreateAccessToken_code(t *testing.T) {
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err)
assertTokens(t, tokens, false)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
// callback on a succeeded request must fail
linkResp, err = Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
@ -108,7 +108,7 @@ func TestOPStorage_CreateAccessToken_implicit(t *testing.T) {
require.NoError(t, err)
claims, err := rp.VerifyTokens[*oidc.IDTokenClaims](context.Background(), accessToken, idToken, provider.IDTokenVerifier())
require.NoError(t, err)
assertIDTokenClaims(t, claims, User.GetUserId(), armPasskey, startTime, changeTime)
assertIDTokenClaims(t, claims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
// callback on a succeeded request must fail
linkResp, err = Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
@ -143,7 +143,7 @@ func TestOPStorage_CreateAccessAndRefreshTokens_code(t *testing.T) {
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err)
assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
}
func TestOPStorage_CreateAccessAndRefreshTokens_refresh(t *testing.T) {
@ -168,14 +168,14 @@ func TestOPStorage_CreateAccessAndRefreshTokens_refresh(t *testing.T) {
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err)
assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
// test actual refresh grant
newTokens, err := refreshTokens(t, clientID, tokens.RefreshToken)
require.NoError(t, err)
assertTokens(t, newTokens, true)
// auth time must still be the initial
assertIDTokenClaims(t, newTokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
assertIDTokenClaims(t, newTokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
// refresh with an old refresh_token must fail
_, err = rp.RefreshTokens[*oidc.IDTokenClaims](CTX, provider, tokens.RefreshToken, "", "")
@ -204,7 +204,7 @@ func TestOPStorage_RevokeToken_access_token(t *testing.T) {
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err)
assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
// revoke access token
err = rp.RevokeToken(CTX, provider, tokens.AccessToken, "access_token")
@ -247,7 +247,7 @@ func TestOPStorage_RevokeToken_access_token_invalid_token_hint_type(t *testing.T
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err)
assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
// revoke access token
err = rp.RevokeToken(CTX, provider, tokens.AccessToken, "refresh_token")
@ -284,7 +284,7 @@ func TestOPStorage_RevokeToken_refresh_token(t *testing.T) {
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err)
assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
// revoke refresh token -> invalidates also access token
err = rp.RevokeToken(CTX, provider, tokens.RefreshToken, "refresh_token")
@ -327,7 +327,7 @@ func TestOPStorage_RevokeToken_refresh_token_invalid_token_type_hint(t *testing.
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err)
assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
// revoke refresh token even with a wrong hint
err = rp.RevokeToken(CTX, provider, tokens.RefreshToken, "access_token")
@ -362,7 +362,7 @@ func TestOPStorage_RevokeToken_invalid_client(t *testing.T) {
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err)
assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
// simulate second client (not part of the audience) trying to revoke the token
otherClientID, _ := createClient(t)
@ -394,7 +394,7 @@ func TestOPStorage_TerminateSession(t *testing.T) {
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err)
assertTokens(t, tokens, false)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
// userinfo must not fail
_, err = rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.Subject, provider)
@ -431,7 +431,7 @@ func TestOPStorage_TerminateSession_refresh_grant(t *testing.T) {
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err)
assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
// userinfo must not fail
_, err = rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.Subject, provider)
@ -475,7 +475,7 @@ func TestOPStorage_TerminateSession_empty_id_token_hint(t *testing.T) {
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err)
assertTokens(t, tokens, false)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
postLogoutRedirect, err := rp.EndSession(CTX, provider, "", logoutRedirectURI, "state")
require.NoError(t, err)
@ -530,8 +530,13 @@ func assertTokens(t *testing.T, tokens *oidc.Tokens[*oidc.IDTokenClaims], requir
}
}
func assertIDTokenClaims(t *testing.T, claims *oidc.IDTokenClaims, userID string, arm []string, sessionStart, sessionChange time.Time) {
func assertIDTokenClaims(t *testing.T, claims *oidc.IDTokenClaims, userID string, arm []string, sessionStart, sessionChange time.Time, sessionID string) {
assert.Equal(t, userID, claims.Subject)
assert.Equal(t, arm, claims.AuthenticationMethodsReferences)
assertOIDCTimeRange(t, claims.AuthTime, sessionStart, sessionChange)
assert.Equal(t, sessionID, claims.SessionID)
assert.Empty(t, claims.Name)
assert.Empty(t, claims.GivenName)
assert.Empty(t, claims.FamilyName)
assert.Empty(t, claims.PreferredUsername)
}

View File

@ -122,7 +122,7 @@ func TestServer_Introspect(t *testing.T) {
tokens, err := exchangeTokens(t, app.GetClientId(), code, redirectURI)
require.NoError(t, err)
assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
// test actual introspection
introspection, err := rs.Introspect[*oidc.IntrospectionResponse](context.Background(), resourceServer, tokens.AccessToken)
@ -317,7 +317,7 @@ func TestServer_VerifyClient(t *testing.T) {
}
require.NoError(t, err)
assertTokens(t, tokens, false)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
})
}
}

View File

@ -11,6 +11,7 @@ import (
"github.com/zitadel/oidc/v3/pkg/op"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
"github.com/zitadel/zitadel/internal/zerrors"
@ -103,7 +104,14 @@ func (s *Server) Introspect(ctx context.Context, r *op.Request[op.IntrospectionR
if err = validateIntrospectionAudience(token.audience, client.clientID, client.projectID); err != nil {
return nil, err
}
userInfo, err := s.userInfo(ctx, token.userID, token.scope, client.projectID, client.projectRoleAssertion, true)
userInfo, err := s.userInfo(
token.userID,
token.scope,
client.projectID,
client.projectRoleAssertion,
true,
true,
)(ctx, true, domain.TriggerTypePreUserinfoCreation)
if err != nil {
return nil, err
}

View File

@ -77,7 +77,7 @@ func Test_ZITADEL_API_missing_audience_scope(t *testing.T) {
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err)
assertTokens(t, tokens, false)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken))
@ -142,7 +142,7 @@ func Test_ZITADEL_API_missing_mfa_2fa_setup(t *testing.T) {
code := assertCodeResponse(t, linkResp.GetCallbackUrl())
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err)
assertIDTokenClaims(t, tokens.IDTokenClaims, userResp.GetUserId(), armPassword, startTime, changeTime)
assertIDTokenClaims(t, tokens.IDTokenClaims, userResp.GetUserId(), armPassword, startTime, changeTime, sessionID)
ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken))
@ -173,7 +173,7 @@ func Test_ZITADEL_API_missing_mfa_policy(t *testing.T) {
code := assertCodeResponse(t, linkResp.GetCallbackUrl())
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err)
assertIDTokenClaims(t, tokens.IDTokenClaims, userID, armPassword, startTime, changeTime)
assertIDTokenClaims(t, tokens.IDTokenClaims, userID, armPassword, startTime, changeTime, sessionID)
ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken))
@ -227,7 +227,7 @@ func Test_ZITADEL_API_success(t *testing.T) {
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err)
assertTokens(t, tokens, false)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken))
@ -261,7 +261,7 @@ func Test_ZITADEL_API_glob_redirects(t *testing.T) {
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err)
assertTokens(t, tokens, false)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken))
@ -290,7 +290,7 @@ func Test_ZITADEL_API_inactive_access_token(t *testing.T) {
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err)
assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
// make sure token works
ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken))
@ -332,7 +332,7 @@ func Test_ZITADEL_API_terminated_session(t *testing.T) {
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err)
assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
// make sure token works
ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken))
@ -402,7 +402,7 @@ func Test_ZITADEL_API_terminated_session_user_disabled(t *testing.T) {
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err)
assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, disabledUser.GetUserId(), armPassword, startTime, changeTime)
assertIDTokenClaims(t, tokens.IDTokenClaims, disabledUser.GetUserId(), armPassword, startTime, changeTime, sessionID)
// make sure token works
ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken))

View File

@ -29,8 +29,8 @@ In some cases step 1 till 3 are completely implemented in the command package,
for example the v2 code exchange and refresh token.
*/
func (s *Server) accessTokenResponseFromSession(ctx context.Context, client op.Client, session *command.OIDCSession, state, projectID string, projectRoleAssertion bool) (_ *oidc.AccessTokenResponse, err error) {
getUserInfo := s.getUserInfoOnce(session.UserID, projectID, projectRoleAssertion, session.Scope)
func (s *Server) accessTokenResponseFromSession(ctx context.Context, client op.Client, session *command.OIDCSession, state, projectID string, projectRoleAssertion, accessTokenRoleAssertion, idTokenRoleAssertion, userInfoAssertion bool) (_ *oidc.AccessTokenResponse, err error) {
getUserInfo := s.getUserInfo(session.UserID, projectID, projectRoleAssertion, userInfoAssertion, session.Scope)
getSigner := s.getSignerOnce()
resp := &oidc.AccessTokenResponse{
@ -43,7 +43,7 @@ func (s *Server) accessTokenResponseFromSession(ctx context.Context, client op.C
// If the session does not have a token ID, it is an implicit ID-Token only response.
if session.TokenID != "" {
if client.AccessTokenType() == op.AccessTokenTypeJWT {
resp.AccessToken, err = s.createJWT(ctx, client, session, getUserInfo, getSigner)
resp.AccessToken, err = s.createJWT(ctx, client, session, getUserInfo, accessTokenRoleAssertion, getSigner)
} else {
resp.AccessToken, err = op.CreateBearerToken(session.TokenID, session.UserID, s.opCrypto)
}
@ -53,7 +53,7 @@ func (s *Server) accessTokenResponseFromSession(ctx context.Context, client op.C
}
if slices.Contains(session.Scope, oidc.ScopeOpenID) {
resp.IDToken, _, err = s.createIDToken(ctx, client, getUserInfo, getSigner, session.SessionID, resp.AccessToken, session.Audience, session.AuthMethods, session.AuthTime, session.Nonce, session.Actor)
resp.IDToken, _, err = s.createIDToken(ctx, client, getUserInfo, idTokenRoleAssertion, getSigner, session.SessionID, resp.AccessToken, session.Audience, session.AuthMethods, session.AuthTime, session.Nonce, session.Actor)
}
return resp, err
}
@ -92,31 +92,22 @@ func (s *Server) getSignerOnce() signerFunc {
}
// userInfoFunc is a getter function that allows add-hoc retrieval of a user.
type userInfoFunc func(ctx context.Context) (*oidc.UserInfo, error)
type userInfoFunc func(ctx context.Context, roleAssertion bool, triggerType domain.TriggerType) (*oidc.UserInfo, error)
// getUserInfoOnce returns a function which retrieves userinfo from the database once.
// Repeated calls of the returned function return the same results.
func (s *Server) getUserInfoOnce(userID, projectID string, projectRoleAssertion bool, scope []string) userInfoFunc {
var (
once sync.Once
userInfo *oidc.UserInfo
err error
)
return func(ctx context.Context) (*oidc.UserInfo, error) {
once.Do(func() {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
userInfo, err = s.userInfo(ctx, userID, scope, projectID, projectRoleAssertion, false)
})
return userInfo, err
// getUserInfo returns a function which retrieves userinfo from the database once.
// However, each time, role claims are asserted and also action flows will trigger.
func (s *Server) getUserInfo(userID, projectID string, projectRoleAssertion, userInfoAssertion bool, scope []string) userInfoFunc {
userInfo := s.userInfo(userID, scope, projectID, projectRoleAssertion, userInfoAssertion, false)
return func(ctx context.Context, roleAssertion bool, triggerType domain.TriggerType) (*oidc.UserInfo, error) {
return userInfo(ctx, roleAssertion, triggerType)
}
}
func (*Server) createIDToken(ctx context.Context, client op.Client, getUserInfo userInfoFunc, getSigningKey signerFunc, sessionID, accessToken string, audience []string, authMethods []domain.UserAuthMethodType, authTime time.Time, nonce string, actor *domain.TokenActor) (idToken string, exp uint64, err error) {
func (*Server) createIDToken(ctx context.Context, client op.Client, getUserInfo userInfoFunc, roleAssertion bool, getSigningKey signerFunc, sessionID, accessToken string, audience []string, authMethods []domain.UserAuthMethodType, authTime time.Time, nonce string, actor *domain.TokenActor) (idToken string, exp uint64, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
userInfo, err := getUserInfo(ctx)
userInfo, err := getUserInfo(ctx, roleAssertion, domain.TriggerTypePreUserinfoCreation)
if err != nil {
return "", 0, err
}
@ -156,11 +147,11 @@ func timeToOIDCExpiresIn(exp time.Time) uint64 {
return uint64(time.Until(exp) / time.Second)
}
func (*Server) createJWT(ctx context.Context, client op.Client, session *command.OIDCSession, getUserInfo userInfoFunc, getSigner signerFunc) (_ string, err error) {
func (s *Server) createJWT(ctx context.Context, client op.Client, session *command.OIDCSession, getUserInfo userInfoFunc, assertRoles bool, getSigner signerFunc) (_ string, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
userInfo, err := getUserInfo(ctx)
userInfo, err := getUserInfo(ctx, assertRoles, domain.TriggerTypePreAccessTokenCreation)
if err != nil {
return "", err
}

View File

@ -47,5 +47,5 @@ func (s *Server) ClientCredentialsExchange(ctx context.Context, r *op.ClientRequ
false,
)
return response(s.accessTokenResponseFromSession(ctx, client, session, "", "", false))
return response(s.accessTokenResponseFromSession(ctx, client, session, "", "", false, true, false, false))
}

View File

@ -4,6 +4,7 @@ package oidc_test
import (
"testing"
"time"
"github.com/brianvoe/gofakeit/v6"
"github.com/stretchr/testify/assert"
@ -18,10 +19,13 @@ import (
)
func TestServer_ClientCredentialsExchange(t *testing.T) {
userID, clientID, clientSecret, err := Tester.CreateOIDCCredentialsClient(CTX)
machine, name, clientID, clientSecret, err := Tester.CreateOIDCCredentialsClient(CTX)
require.NoError(t, err)
type claims struct {
name string
username string
updated time.Time
resourceOwnerID any
resourceOwnerName any
resourceOwnerPrimaryDomain any
@ -78,6 +82,17 @@ func TestServer_ClientCredentialsExchange(t *testing.T) {
clientSecret: clientSecret,
scope: []string{oidc.ScopeOpenID},
},
{
name: "openid, profile, email",
clientID: clientID,
clientSecret: clientSecret,
scope: []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail},
wantClaims: claims{
name: name,
username: name,
updated: machine.GetDetails().GetChangeDate().AsTime(),
},
},
{
name: "org id and domain scope",
clientID: clientID,
@ -132,12 +147,20 @@ func TestServer_ClientCredentialsExchange(t *testing.T) {
}
require.NoError(t, err)
require.NotNil(t, tokens)
userinfo, err := rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, oidc.BearerToken, userID, provider)
userinfo, err := rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, oidc.BearerToken, machine.GetUserId(), provider)
require.NoError(t, err)
assert.Equal(t, tt.wantClaims.resourceOwnerID, userinfo.Claims[oidc_api.ClaimResourceOwnerID])
assert.Equal(t, tt.wantClaims.resourceOwnerName, userinfo.Claims[oidc_api.ClaimResourceOwnerName])
assert.Equal(t, tt.wantClaims.resourceOwnerPrimaryDomain, userinfo.Claims[oidc_api.ClaimResourceOwnerPrimaryDomain])
assert.Equal(t, tt.wantClaims.orgDomain, userinfo.Claims[domain.OrgDomainPrimaryClaim])
assert.Equal(t, tt.wantClaims.name, userinfo.Name)
assert.Equal(t, tt.wantClaims.username, userinfo.PreferredUsername)
assertOIDCTime(t, userinfo.UpdatedAt, tt.wantClaims.updated)
assert.Empty(t, userinfo.UserInfoProfile.FamilyName)
assert.Empty(t, userinfo.UserInfoProfile.GivenName)
assert.Empty(t, userinfo.UserInfoEmail)
assert.Empty(t, userinfo.UserInfoPhone)
assert.Empty(t, userinfo.Address)
})
}
}

View File

@ -49,7 +49,7 @@ func (s *Server) CodeExchange(ctx context.Context, r *op.ClientRequest[oidc.Acce
if err != nil {
return nil, err
}
return response(s.accessTokenResponseFromSession(ctx, client, session, state, client.client.ProjectID, client.client.ProjectRoleAssertion))
return response(s.accessTokenResponseFromSession(ctx, client, session, state, client.client.ProjectID, client.client.ProjectRoleAssertion, client.client.AccessTokenRoleAssertion, client.client.IDTokenRoleAssertion, client.client.IDTokenUserinfoAssertion))
}
// codeExchangeV1 creates a v2 token from a v1 auth request.

View File

@ -26,7 +26,7 @@ func (s *Server) DeviceToken(ctx context.Context, r *op.ClientRequest[oidc.Devic
}
session, err := s.command.CreateOIDCSessionFromDeviceAuth(ctx, r.Data.DeviceCode)
if err == nil {
return response(s.accessTokenResponseFromSession(ctx, client, session, "", client.client.ProjectID, client.client.ProjectRoleAssertion))
return response(s.accessTokenResponseFromSession(ctx, client, session, "", client.client.ProjectID, client.client.ProjectRoleAssertion, client.client.AccessTokenRoleAssertion, client.client.IDTokenRoleAssertion, client.client.IDTokenUserinfoAssertion))
}
if errors.Is(err, context.DeadlineExceeded) {
return nil, oidc.ErrSlowDown().WithParent(err)

View File

@ -218,7 +218,7 @@ func validateTokenExchangeAudience(requestedAudience, subjectAudience, actorAudi
// Both tokens may point to the same object (subjectToken) in case of a regular Token Exchange.
// When the subject and actor Tokens point to different objects, the new tokens will be for impersonation / delegation.
func (s *Server) createExchangeTokens(ctx context.Context, tokenType oidc.TokenType, client *Client, subjectToken, actorToken *exchangeToken, audience, scopes []string) (_ *oidc.TokenExchangeResponse, err error) {
getUserInfo := s.getUserInfoOnce(subjectToken.userID, client.client.ProjectID, client.client.ProjectRoleAssertion, scopes)
getUserInfo := s.getUserInfo(subjectToken.userID, client.client.ProjectID, client.client.ProjectRoleAssertion, client.IDTokenUserinfoClaimsAssertion(), scopes)
getSigner := s.getSignerOnce()
resp := &oidc.TokenExchangeResponse{
@ -240,12 +240,12 @@ func (s *Server) createExchangeTokens(ctx context.Context, tokenType oidc.TokenT
resp.IssuedTokenType = oidc.AccessTokenType
case oidc.JWTTokenType:
resp.AccessToken, resp.RefreshToken, resp.ExpiresIn, err = s.createExchangeJWT(ctx, client, getUserInfo, getSigner, subjectToken.userID, subjectToken.resourceOwner, audience, scopes, actorToken.authMethods, actorToken.authTime, subjectToken.preferredLanguage, reason, actor)
resp.AccessToken, resp.RefreshToken, resp.ExpiresIn, err = s.createExchangeJWT(ctx, client, getUserInfo, client.client.AccessTokenRoleAssertion, getSigner, subjectToken.userID, subjectToken.resourceOwner, audience, scopes, actorToken.authMethods, actorToken.authTime, subjectToken.preferredLanguage, reason, actor)
resp.TokenType = oidc.BearerToken
resp.IssuedTokenType = oidc.JWTTokenType
case oidc.IDTokenType:
resp.AccessToken, resp.ExpiresIn, err = s.createIDToken(ctx, client, getUserInfo, getSigner, "", resp.AccessToken, audience, actorToken.authMethods, actorToken.authTime, "", actor)
resp.AccessToken, resp.ExpiresIn, err = s.createIDToken(ctx, client, getUserInfo, client.client.IDTokenRoleAssertion, getSigner, "", resp.AccessToken, audience, actorToken.authMethods, actorToken.authTime, "", actor)
resp.TokenType = TokenTypeNA
resp.IssuedTokenType = oidc.IDTokenType
@ -259,7 +259,7 @@ func (s *Server) createExchangeTokens(ctx context.Context, tokenType oidc.TokenT
}
if slices.Contains(scopes, oidc.ScopeOpenID) && tokenType != oidc.IDTokenType {
resp.IDToken, _, err = s.createIDToken(ctx, client, getUserInfo, getSigner, sessionID, resp.AccessToken, audience, actorToken.authMethods, actorToken.authTime, "", actor)
resp.IDToken, _, err = s.createIDToken(ctx, client, getUserInfo, client.client.IDTokenRoleAssertion, getSigner, sessionID, resp.AccessToken, audience, actorToken.authMethods, actorToken.authTime, "", actor)
if err != nil {
return nil, err
}
@ -313,6 +313,7 @@ func (s *Server) createExchangeJWT(
ctx context.Context,
client *Client,
getUserInfo userInfoFunc,
roleAssertion bool,
getSigner signerFunc,
userID,
resourceOwner string,
@ -342,7 +343,7 @@ func (s *Server) createExchangeJWT(
actor,
slices.Contains(scope, oidc.ScopeOfflineAccess),
)
accessToken, err = s.createJWT(ctx, client, session, getUserInfo, getSigner)
accessToken, err = s.createJWT(ctx, client, session, getUserInfo, roleAssertion, getSigner)
if err != nil {
return "", "", 0, err
}

View File

@ -54,7 +54,7 @@ func (s *Server) JWTProfile(ctx context.Context, r *op.Request[oidc.JWTProfileGr
nil,
false,
)
return response(s.accessTokenResponseFromSession(ctx, client, session, "", "", false))
return response(s.accessTokenResponseFromSession(ctx, client, session, "", "", false, true, false, false))
}
func (s *Server) verifyJWTProfile(ctx context.Context, req *oidc.JWTProfileGrantRequest) (user *query.User, tokenRequest *oidc.JWTTokenRequest, err error) {

View File

@ -4,6 +4,7 @@ package oidc_test
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -16,10 +17,13 @@ import (
)
func TestServer_JWTProfile(t *testing.T) {
userID, keyData, err := Tester.CreateOIDCJWTProfileClient(CTX)
user, name, keyData, err := Tester.CreateOIDCJWTProfileClient(CTX)
require.NoError(t, err)
type claims struct {
name string
username string
updated time.Time
resourceOwnerID any
resourceOwnerName any
resourceOwnerPrimaryDomain any
@ -37,6 +41,16 @@ func TestServer_JWTProfile(t *testing.T) {
keyData: keyData,
scope: []string{oidc.ScopeOpenID},
},
{
name: "openid, profile, email",
keyData: keyData,
scope: []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail},
wantClaims: claims{
name: name,
username: name,
updated: user.GetDetails().GetChangeDate().AsTime(),
},
},
{
name: "org id and domain scope",
keyData: keyData,
@ -92,12 +106,20 @@ func TestServer_JWTProfile(t *testing.T) {
provider, err := rp.NewRelyingPartyOIDC(CTX, Tester.OIDCIssuer(), "", "", redirectURI, tt.scope)
require.NoError(t, err)
userinfo, err := rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, oidc.BearerToken, userID, provider)
userinfo, err := rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, oidc.BearerToken, user.GetUserId(), provider)
require.NoError(t, err)
assert.Equal(t, tt.wantClaims.resourceOwnerID, userinfo.Claims[oidc_api.ClaimResourceOwnerID])
assert.Equal(t, tt.wantClaims.resourceOwnerName, userinfo.Claims[oidc_api.ClaimResourceOwnerName])
assert.Equal(t, tt.wantClaims.resourceOwnerPrimaryDomain, userinfo.Claims[oidc_api.ClaimResourceOwnerPrimaryDomain])
assert.Equal(t, tt.wantClaims.orgDomain, userinfo.Claims[domain.OrgDomainPrimaryClaim])
assert.Equal(t, tt.wantClaims.name, userinfo.Name)
assert.Equal(t, tt.wantClaims.username, userinfo.PreferredUsername)
assertOIDCTime(t, userinfo.UpdatedAt, tt.wantClaims.updated)
assert.Empty(t, userinfo.UserInfoProfile.FamilyName)
assert.Empty(t, userinfo.UserInfoProfile.GivenName)
assert.Empty(t, userinfo.UserInfoEmail)
assert.Empty(t, userinfo.UserInfoPhone)
assert.Empty(t, userinfo.Address)
})
}
}

View File

@ -28,7 +28,7 @@ func (s *Server) RefreshToken(ctx context.Context, r *op.ClientRequest[oidc.Refr
session, err := s.command.ExchangeOIDCSessionRefreshAndAccessToken(ctx, r.Data.RefreshToken, r.Data.Scopes, refreshTokenComplianceChecker())
if err == nil {
return response(s.accessTokenResponseFromSession(ctx, client, session, "", client.client.ProjectID, client.client.ProjectRoleAssertion))
return response(s.accessTokenResponseFromSession(ctx, client, session, "", client.client.ProjectID, client.client.ProjectRoleAssertion, client.client.AccessTokenRoleAssertion, client.client.IDTokenRoleAssertion, client.client.IDTokenUserinfoAssertion))
} else if errors.Is(err, zerrors.ThrowPreconditionFailed(nil, "OIDCS-JOI23", "Errors.OIDCSession.RefreshTokenInvalid")) {
// We try again for v1 tokens when we encountered specific parsing error
return s.refreshTokenV1(ctx, client, r)
@ -78,7 +78,7 @@ func (s *Server) refreshTokenV1(ctx context.Context, client *Client, r *op.Clien
return nil, err
}
return response(s.accessTokenResponseFromSession(ctx, client, session, "", client.client.ProjectID, client.client.ProjectRoleAssertion))
return response(s.accessTokenResponseFromSession(ctx, client, session, "", client.client.ProjectID, client.client.ProjectRoleAssertion, client.client.AccessTokenRoleAssertion, client.client.IDTokenRoleAssertion, client.client.IDTokenUserinfoAssertion))
}
// refreshTokenComplianceChecker validates that the requested scope is a subset of the original auth request scope.

View File

@ -8,6 +8,7 @@ import (
"net/http"
"slices"
"strings"
"sync"
"github.com/dop251/goja"
"github.com/zitadel/logging"
@ -55,7 +56,14 @@ func (s *Server) UserInfo(ctx context.Context, r *op.Request[oidc.UserInfoReques
}
}
userInfo, err := s.userInfo(ctx, token.userID, token.scope, projectID, assertion, false)
userInfo, err := s.userInfo(
token.userID,
token.scope,
projectID,
assertion,
true,
false,
)(ctx, true, domain.TriggerTypePreUserinfoCreation)
if err != nil {
return nil, err
}
@ -66,24 +74,44 @@ func (s *Server) UserInfo(ctx context.Context, r *op.Request[oidc.UserInfoReques
// The returned UserInfo contains standard and reserved claims, documented
// here: https://zitadel.com/docs/apis/openidoauth/claims.
//
// User information is only retrieved once from the database.
// However, each time, role claims are asserted and also action flows will trigger.
//
// projectID is an optional parameter which defines the default audience when there are any (or all) role claims requested.
// projectRoleAssertion sets the default of returning all project roles, only if no specific roles were requested in the scope.
// roleAssertion decides whether the roles will be returned (in the token or response)
// userInfoAssertion decides whether the user information (profile data like name, email, ...) are returned
//
// currentProjectOnly can be set to use the current project ID only and ignore the audience from the scope.
// It should be set in cases where the client doesn't need to know roles outside its own project,
// for example an introspection client.
func (s *Server) userInfo(ctx context.Context, userID string, scope []string, projectID string, projectRoleAssertion, currentProjectOnly bool) (_ *oidc.UserInfo, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
func (s *Server) userInfo(
userID string,
scope []string,
projectID string,
projectRoleAssertion, userInfoAssertion, currentProjectOnly bool,
) func(ctx context.Context, roleAssertion bool, triggerType domain.TriggerType) (_ *oidc.UserInfo, err error) {
var (
once sync.Once
userInfo *oidc.UserInfo
qu *query.OIDCUserInfo
roleAudience, requestedRoles []string
)
return func(ctx context.Context, roleAssertion bool, triggerType domain.TriggerType) (_ *oidc.UserInfo, err error) {
once.Do(func() {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
roleAudience, requestedRoles := prepareRoles(ctx, scope, projectID, projectRoleAssertion, currentProjectOnly)
qu, err := s.query.GetOIDCUserInfo(ctx, userID, roleAudience)
if err != nil {
return nil, err
roleAudience, requestedRoles = prepareRoles(ctx, scope, projectID, projectRoleAssertion, currentProjectOnly)
qu, err = s.query.GetOIDCUserInfo(ctx, userID, roleAudience)
if err != nil {
return
}
userInfo = userInfoToOIDC(qu, userInfoAssertion, scope, s.assetAPIPrefix(ctx))
})
userInfoWithRoles := assertRoles(projectID, qu, roleAudience, requestedRoles, roleAssertion, userInfo)
return userInfoWithRoles, s.userinfoFlows(ctx, qu, userInfoWithRoles, triggerType)
}
userInfo := userInfoToOIDC(projectID, qu, scope, roleAudience, requestedRoles, s.assetAPIPrefix(ctx))
return userInfo, s.userinfoFlows(ctx, qu, userInfo)
}
// prepareRoles scans the requested scopes and builds the requested roles
@ -120,20 +148,32 @@ func prepareRoles(ctx context.Context, scope []string, projectID string, project
return roleAudience, requestedRoles
}
func userInfoToOIDC(projectID string, user *query.OIDCUserInfo, scope, roleAudience, requestedRoles []string, assetPrefix string) *oidc.UserInfo {
func userInfoToOIDC(user *query.OIDCUserInfo, userInfoAssertion bool, scope []string, assetPrefix string) *oidc.UserInfo {
out := new(oidc.UserInfo)
for _, s := range scope {
switch s {
case oidc.ScopeOpenID:
out.Subject = user.User.ID
case oidc.ScopeEmail:
if !userInfoAssertion {
continue
}
out.UserInfoEmail = userInfoEmailToOIDC(user.User)
case oidc.ScopeProfile:
if !userInfoAssertion {
continue
}
out.UserInfoProfile = userInfoProfileToOidc(user.User, assetPrefix)
case oidc.ScopePhone:
if !userInfoAssertion {
continue
}
out.UserInfoPhone = userInfoPhoneToOIDC(user.User)
case oidc.ScopeAddress:
//TODO: handle address for human users as soon as implemented
if !userInfoAssertion {
continue
}
// TODO: handle address for human users as soon as implemented
case ScopeUserMetaData:
setUserInfoMetadata(user.Metadata, out)
case ScopeResourceOwner:
@ -148,12 +188,19 @@ func userInfoToOIDC(projectID string, user *query.OIDCUserInfo, scope, roleAudie
}
}
}
return out
}
func assertRoles(projectID string, user *query.OIDCUserInfo, roleAudience, requestedRoles []string, assertion bool, info *oidc.UserInfo) *oidc.UserInfo {
if !assertion {
return info
}
userInfo := *info
// prevent returning obtained grants if none where requested
if (projectID != "" && len(requestedRoles) > 0) || len(roleAudience) > 0 {
setUserInfoRoleClaims(out, newProjectRoles(projectID, user.UserGrants, requestedRoles))
setUserInfoRoleClaims(&userInfo, newProjectRoles(projectID, user.UserGrants, requestedRoles))
}
return out
return &userInfo
}
func userInfoEmailToOIDC(user *query.User) oidc.UserInfoEmail {
@ -230,11 +277,11 @@ func setUserInfoRoleClaims(userInfo *oidc.UserInfo, roles *projectsRoles) {
}
}
func (s *Server) userinfoFlows(ctx context.Context, qu *query.OIDCUserInfo, userInfo *oidc.UserInfo) (err error) {
func (s *Server) userinfoFlows(ctx context.Context, qu *query.OIDCUserInfo, userInfo *oidc.UserInfo, triggerType domain.TriggerType) (err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
queriedActions, err := s.query.GetActiveActionsByFlowAndTriggerType(ctx, domain.FlowTypeCustomiseToken, domain.TriggerTypePreUserinfoCreation, qu.User.ResourceOwner)
queriedActions, err := s.query.GetActiveActionsByFlowAndTriggerType(ctx, domain.FlowTypeCustomiseToken, triggerType, qu.User.ResourceOwner)
if err != nil {
return err
}

View File

@ -231,9 +231,9 @@ func TestServer_UserInfo_Issue6662(t *testing.T) {
project, err := Tester.CreateProject(CTX)
projectID := project.GetId()
require.NoError(t, err)
userID, clientID, clientSecret, err := Tester.CreateOIDCCredentialsClient(CTX)
user, _, clientID, clientSecret, err := Tester.CreateOIDCCredentialsClient(CTX)
require.NoError(t, err)
addProjectRolesGrants(t, userID, projectID, roleFoo, roleBar)
addProjectRolesGrants(t, user.GetUserId(), projectID, roleFoo, roleBar)
scope := []string{oidc.ScopeProfile, oidc.ScopeOpenID, oidc.ScopeEmail, oidc.ScopeOfflineAccess,
oidc_api.ScopeProjectRolePrefix + roleFoo,
@ -245,7 +245,7 @@ func TestServer_UserInfo_Issue6662(t *testing.T) {
tokens, err := rp.ClientCredentials(CTX, provider, nil)
require.NoError(t, err)
userinfo, err := rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, tokens.TokenType, userID, provider)
userinfo, err := rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, tokens.TokenType, user.GetUserId(), provider)
require.NoError(t, err)
assertProjectRoleClaims(t, projectID, userinfo.Claims, false, roleFoo)
}
@ -291,7 +291,7 @@ func getTokens(t *testing.T, clientID string, scope []string) *oidc.Tokens[*oidc
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err)
assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
return tokens
}

View File

@ -3,7 +3,6 @@ package oidc
import (
"context"
"encoding/base64"
"fmt"
"testing"
"time"
@ -267,11 +266,9 @@ func Test_userInfoToOIDC(t *testing.T) {
}
type args struct {
projectID string
user *query.OIDCUserInfo
scope []string
roleAudience []string
requestedRoles []string
user *query.OIDCUserInfo
userInfoAssertion bool
scope []string
}
tests := []struct {
name string
@ -281,25 +278,22 @@ func Test_userInfoToOIDC(t *testing.T) {
{
name: "human, empty",
args: args{
projectID: "project1",
user: humanUserInfo,
user: humanUserInfo,
},
want: &oidc.UserInfo{},
},
{
name: "machine, empty",
args: args{
projectID: "project1",
user: machineUserInfo,
user: machineUserInfo,
},
want: &oidc.UserInfo{},
},
{
name: "human, scope openid",
args: args{
projectID: "project1",
user: humanUserInfo,
scope: []string{oidc.ScopeOpenID},
user: humanUserInfo,
scope: []string{oidc.ScopeOpenID},
},
want: &oidc.UserInfo{
Subject: "human1",
@ -308,20 +302,19 @@ func Test_userInfoToOIDC(t *testing.T) {
{
name: "machine, scope openid",
args: args{
projectID: "project1",
user: machineUserInfo,
scope: []string{oidc.ScopeOpenID},
user: machineUserInfo,
scope: []string{oidc.ScopeOpenID},
},
want: &oidc.UserInfo{
Subject: "machine1",
},
},
{
name: "human, scope email",
name: "human, scope email, profileInfoAssertion",
args: args{
projectID: "project1",
user: humanUserInfo,
scope: []string{oidc.ScopeEmail},
user: humanUserInfo,
userInfoAssertion: true,
scope: []string{oidc.ScopeEmail},
},
want: &oidc.UserInfo{
UserInfoEmail: oidc.UserInfoEmail{
@ -331,22 +324,29 @@ func Test_userInfoToOIDC(t *testing.T) {
},
},
{
name: "machine, scope email",
name: "human, scope email",
args: args{
projectID: "project1",
user: machineUserInfo,
scope: []string{oidc.ScopeEmail},
user: humanUserInfo,
scope: []string{oidc.ScopeEmail},
},
want: &oidc.UserInfo{},
},
{
name: "machine, scope email, profileInfoAssertion",
args: args{
user: machineUserInfo,
scope: []string{oidc.ScopeEmail},
},
want: &oidc.UserInfo{
UserInfoEmail: oidc.UserInfoEmail{},
},
},
{
name: "human, scope profile",
name: "human, scope profile, profileInfoAssertion",
args: args{
projectID: "project1",
user: humanUserInfo,
scope: []string{oidc.ScopeProfile},
user: humanUserInfo,
userInfoAssertion: true,
scope: []string{oidc.ScopeProfile},
},
want: &oidc.UserInfo{
UserInfoProfile: oidc.UserInfoProfile{
@ -363,11 +363,11 @@ func Test_userInfoToOIDC(t *testing.T) {
},
},
{
name: "machine, scope profile",
name: "machine, scope profile, profileInfoAssertion",
args: args{
projectID: "project1",
user: machineUserInfo,
scope: []string{oidc.ScopeProfile},
user: machineUserInfo,
userInfoAssertion: true,
scope: []string{oidc.ScopeProfile},
},
want: &oidc.UserInfo{
UserInfoProfile: oidc.UserInfoProfile{
@ -378,11 +378,19 @@ func Test_userInfoToOIDC(t *testing.T) {
},
},
{
name: "human, scope phone",
name: "machine, scope profile",
args: args{
projectID: "project1",
user: humanUserInfo,
scope: []string{oidc.ScopePhone},
user: machineUserInfo,
scope: []string{oidc.ScopeProfile},
},
want: &oidc.UserInfo{},
},
{
name: "human, scope phone, profileInfoAssertion",
args: args{
user: humanUserInfo,
userInfoAssertion: true,
scope: []string{oidc.ScopePhone},
},
want: &oidc.UserInfo{
UserInfoPhone: oidc.UserInfoPhone{
@ -391,12 +399,19 @@ func Test_userInfoToOIDC(t *testing.T) {
},
},
},
{
name: "human, scope phone",
args: args{
user: humanUserInfo,
scope: []string{oidc.ScopePhone},
},
want: &oidc.UserInfo{},
},
{
name: "machine, scope phone",
args: args{
projectID: "project1",
user: machineUserInfo,
scope: []string{oidc.ScopePhone},
user: machineUserInfo,
scope: []string{oidc.ScopePhone},
},
want: &oidc.UserInfo{
UserInfoPhone: oidc.UserInfoPhone{},
@ -405,9 +420,8 @@ func Test_userInfoToOIDC(t *testing.T) {
{
name: "human, scope metadata",
args: args{
projectID: "project1",
user: humanUserInfo,
scope: []string{ScopeUserMetaData},
user: humanUserInfo,
scope: []string{ScopeUserMetaData},
},
want: &oidc.UserInfo{
Claims: map[string]any{
@ -421,18 +435,16 @@ func Test_userInfoToOIDC(t *testing.T) {
{
name: "machine, scope metadata, none found",
args: args{
projectID: "project1",
user: machineUserInfo,
scope: []string{ScopeUserMetaData},
user: machineUserInfo,
scope: []string{ScopeUserMetaData},
},
want: &oidc.UserInfo{},
},
{
name: "machine, scope resource owner",
args: args{
projectID: "project1",
user: machineUserInfo,
scope: []string{ScopeResourceOwner},
user: machineUserInfo,
scope: []string{ScopeResourceOwner},
},
want: &oidc.UserInfo{
Claims: map[string]any{
@ -445,9 +457,8 @@ func Test_userInfoToOIDC(t *testing.T) {
{
name: "human, scope org primary domain prefix",
args: args{
projectID: "project1",
user: humanUserInfo,
scope: []string{domain.OrgDomainPrimaryScope + "foo.com"},
user: humanUserInfo,
scope: []string{domain.OrgDomainPrimaryScope + "foo.com"},
},
want: &oidc.UserInfo{
Claims: map[string]any{
@ -458,9 +469,8 @@ func Test_userInfoToOIDC(t *testing.T) {
{
name: "machine, scope org id",
args: args{
projectID: "project1",
user: machineUserInfo,
scope: []string{domain.OrgIDScope + "orgID"},
user: machineUserInfo,
scope: []string{domain.OrgIDScope + "orgID"},
},
want: &oidc.UserInfo{
Claims: map[string]any{
@ -471,50 +481,11 @@ func Test_userInfoToOIDC(t *testing.T) {
},
},
},
{
name: "human, roleAudience",
args: args{
projectID: "project1",
user: humanUserInfo,
roleAudience: []string{"project1"},
},
want: &oidc.UserInfo{
Claims: map[string]any{
ClaimProjectRoles: projectRoles{
"role1": {"orgID": "orgDomain"},
"role2": {"orgID": "orgDomain"},
},
fmt.Sprintf(ClaimProjectRolesFormat, "project1"): projectRoles{
"role1": {"orgID": "orgDomain"},
"role2": {"orgID": "orgDomain"},
},
},
},
},
{
name: "human, requested roles",
args: args{
projectID: "project1",
user: humanUserInfo,
roleAudience: []string{"project1"},
requestedRoles: []string{"role2"},
},
want: &oidc.UserInfo{
Claims: map[string]any{
ClaimProjectRoles: projectRoles{
"role2": {"orgID": "orgDomain"},
},
fmt.Sprintf(ClaimProjectRolesFormat, "project1"): projectRoles{
"role2": {"orgID": "orgDomain"},
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assetPrefix := "https://foo.com/assets"
got := userInfoToOIDC(tt.args.projectID, tt.args.user, tt.args.scope, tt.args.roleAudience, tt.args.requestedRoles, assetPrefix)
got := userInfoToOIDC(tt.args.user, tt.args.userInfoAssertion, tt.args.scope, assetPrefix)
assert.Equal(t, tt.want, got)
})
}

View File

@ -281,42 +281,42 @@ func CheckRedirect(req *http.Request) (*url.URL, error) {
return resp.Location()
}
func (s *Tester) CreateOIDCCredentialsClient(ctx context.Context) (userID, clientID, clientSecret string, err error) {
name := gofakeit.Username()
user, err := s.Client.Mgmt.AddMachineUser(ctx, &management.AddMachineUserRequest{
func (s *Tester) CreateOIDCCredentialsClient(ctx context.Context) (machine *management.AddMachineUserResponse, name, clientID, clientSecret string, err error) {
name = gofakeit.Username()
machine, err = s.Client.Mgmt.AddMachineUser(ctx, &management.AddMachineUserRequest{
Name: name,
UserName: name,
AccessTokenType: user.AccessTokenType_ACCESS_TOKEN_TYPE_JWT,
})
if err != nil {
return "", "", "", err
return nil, "", "", "", err
}
secret, err := s.Client.Mgmt.GenerateMachineSecret(ctx, &management.GenerateMachineSecretRequest{
UserId: user.GetUserId(),
UserId: machine.GetUserId(),
})
if err != nil {
return "", "", "", err
return nil, "", "", "", err
}
return user.GetUserId(), secret.GetClientId(), secret.GetClientSecret(), nil
return machine, name, secret.GetClientId(), secret.GetClientSecret(), nil
}
func (s *Tester) CreateOIDCJWTProfileClient(ctx context.Context) (userID string, keyData []byte, err error) {
name := gofakeit.Username()
user, err := s.Client.Mgmt.AddMachineUser(ctx, &management.AddMachineUserRequest{
func (s *Tester) CreateOIDCJWTProfileClient(ctx context.Context) (machine *management.AddMachineUserResponse, name string, keyData []byte, err error) {
name = gofakeit.Username()
machine, err = s.Client.Mgmt.AddMachineUser(ctx, &management.AddMachineUserRequest{
Name: name,
UserName: name,
AccessTokenType: user.AccessTokenType_ACCESS_TOKEN_TYPE_JWT,
})
if err != nil {
return "", nil, err
return nil, "", nil, err
}
keyResp, err := s.Client.Mgmt.AddMachineKey(ctx, &management.AddMachineKeyRequest{
UserId: user.GetUserId(),
UserId: machine.GetUserId(),
Type: authn.KeyType_KEY_TYPE_JSON,
ExpirationDate: timestamppb.New(time.Now().Add(time.Hour)),
})
if err != nil {
return "", nil, err
return nil, "", nil, err
}
return user.GetUserId(), keyResp.GetKeyDetails(), nil
return machine, name, keyResp.GetKeyDetails(), nil
}