From f065b42a97ed2372913224e6f2580a7318a9b7b5 Mon Sep 17 00:00:00 2001 From: Livio Spring Date: Fri, 31 May 2024 12:10:18 +0200 Subject: [PATCH] fix(oidc): respect role assertion and idTokenInfo flags and trigger preAccessToken trigger (#8046) # Which Problems Are Solved After deployment of 2.53.x, customers noted that the roles claims where always present in the tokens even if the corresponding option on the client (accessTokenRoleAssertion, idTokenRoleAsseriton) was disabled. Only the project flag (assertRolesOnAuthentication) would be considered. Further it was noted, that the action on the preAccessTokenCreation trigger would not be executed. Additionally, while testing those issues we found out, that the user information (name, givenname, family name, ...) where always present in the id_token even if the option (idTokenUserInfo) was not enabled. # How the Problems Are Solved - The `getUserinfoOnce` which was used for access and id_tokens is refactored to `getUserInfo` and no longer only queries the info once from the database, but still provides a mechanism to be reused for access and id_token where the corresponding `roleAssertion` and action `triggerType` can be passed. - `userInfo` on the other hand now directly makes sure the information is only queried once from the database. Role claims are asserted every time and action triggers are executed on every call. - `userInfo` now also checks if the profile information need to be returned. # Additional Changes None. # Additional Context - relates to #7822 - reported by customers --- internal/api/oidc/auth_request.go | 4 +- .../api/oidc/auth_request_integration_test.go | 33 ++-- internal/api/oidc/client_integration_test.go | 4 +- internal/api/oidc/introspect.go | 10 +- internal/api/oidc/oidc_integration_test.go | 16 +- internal/api/oidc/token.go | 39 ++--- internal/api/oidc/token_client_credentials.go | 2 +- ...ken_client_credentials_integration_test.go | 27 ++- internal/api/oidc/token_code.go | 2 +- internal/api/oidc/token_device.go | 2 +- internal/api/oidc/token_exchange.go | 11 +- internal/api/oidc/token_jwt_profile.go | 2 +- .../token_jwt_profile_integration_test.go | 26 ++- internal/api/oidc/token_refresh.go | 4 +- internal/api/oidc/userinfo.go | 81 +++++++-- .../api/oidc/userinfo_integration_test.go | 8 +- internal/api/oidc/userinfo_test.go | 159 +++++++----------- internal/integration/oidc.go | 28 +-- 18 files changed, 263 insertions(+), 195 deletions(-) diff --git a/internal/api/oidc/auth_request.go b/internal/api/oidc/auth_request.go index 1abe59dc88..5d23a8bd98 100644 --- a/internal/api/oidc/auth_request.go +++ b/internal/api/oidc/auth_request.go @@ -471,7 +471,7 @@ func (s *Server) CreateTokenCallbackURL(ctx context.Context, req op.AuthRequest) if err != nil { return "", err } - resp, err := s.accessTokenResponseFromSession(ctx, client, session, state, client.client.ProjectID, client.client.ProjectRoleAssertion) + resp, err := s.accessTokenResponseFromSession(ctx, client, session, state, client.client.ProjectID, client.client.ProjectRoleAssertion, client.client.AccessTokenRoleAssertion, client.client.IDTokenRoleAssertion, client.client.IDTokenUserinfoAssertion) if err != nil { return "", err } @@ -563,7 +563,7 @@ func (s *Server) authResponseToken(authReq *AuthRequest, authorizer op.Authorize op.AuthRequestError(w, r, authReq, err, authorizer) return err } - resp, err := s.accessTokenResponseFromSession(ctx, client, session, authReq.GetState(), client.client.ProjectID, client.client.ProjectRoleAssertion) + resp, err := s.accessTokenResponseFromSession(ctx, client, session, authReq.GetState(), client.client.ProjectID, client.client.ProjectRoleAssertion, client.client.AccessTokenRoleAssertion, client.client.IDTokenRoleAssertion, client.client.IDTokenUserinfoAssertion) if err != nil { op.AuthRequestError(w, r, authReq, err, authorizer) return err diff --git a/internal/api/oidc/auth_request_integration_test.go b/internal/api/oidc/auth_request_integration_test.go index c36e06c6aa..e2e6ae2f7b 100644 --- a/internal/api/oidc/auth_request_integration_test.go +++ b/internal/api/oidc/auth_request_integration_test.go @@ -54,7 +54,7 @@ func TestOPStorage_CreateAccessToken_code(t *testing.T) { tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, false) - assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID) // callback on a succeeded request must fail linkResp, err = Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{ @@ -108,7 +108,7 @@ func TestOPStorage_CreateAccessToken_implicit(t *testing.T) { require.NoError(t, err) claims, err := rp.VerifyTokens[*oidc.IDTokenClaims](context.Background(), accessToken, idToken, provider.IDTokenVerifier()) require.NoError(t, err) - assertIDTokenClaims(t, claims, User.GetUserId(), armPasskey, startTime, changeTime) + assertIDTokenClaims(t, claims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID) // callback on a succeeded request must fail linkResp, err = Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{ @@ -143,7 +143,7 @@ func TestOPStorage_CreateAccessAndRefreshTokens_code(t *testing.T) { tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, true) - assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID) } func TestOPStorage_CreateAccessAndRefreshTokens_refresh(t *testing.T) { @@ -168,14 +168,14 @@ func TestOPStorage_CreateAccessAndRefreshTokens_refresh(t *testing.T) { tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, true) - assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID) // test actual refresh grant newTokens, err := refreshTokens(t, clientID, tokens.RefreshToken) require.NoError(t, err) assertTokens(t, newTokens, true) // auth time must still be the initial - assertIDTokenClaims(t, newTokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) + assertIDTokenClaims(t, newTokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID) // refresh with an old refresh_token must fail _, err = rp.RefreshTokens[*oidc.IDTokenClaims](CTX, provider, tokens.RefreshToken, "", "") @@ -204,7 +204,7 @@ func TestOPStorage_RevokeToken_access_token(t *testing.T) { tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, true) - assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID) // revoke access token err = rp.RevokeToken(CTX, provider, tokens.AccessToken, "access_token") @@ -247,7 +247,7 @@ func TestOPStorage_RevokeToken_access_token_invalid_token_hint_type(t *testing.T tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, true) - assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID) // revoke access token err = rp.RevokeToken(CTX, provider, tokens.AccessToken, "refresh_token") @@ -284,7 +284,7 @@ func TestOPStorage_RevokeToken_refresh_token(t *testing.T) { tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, true) - assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID) // revoke refresh token -> invalidates also access token err = rp.RevokeToken(CTX, provider, tokens.RefreshToken, "refresh_token") @@ -327,7 +327,7 @@ func TestOPStorage_RevokeToken_refresh_token_invalid_token_type_hint(t *testing. tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, true) - assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID) // revoke refresh token even with a wrong hint err = rp.RevokeToken(CTX, provider, tokens.RefreshToken, "access_token") @@ -362,7 +362,7 @@ func TestOPStorage_RevokeToken_invalid_client(t *testing.T) { tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, true) - assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID) // simulate second client (not part of the audience) trying to revoke the token otherClientID, _ := createClient(t) @@ -394,7 +394,7 @@ func TestOPStorage_TerminateSession(t *testing.T) { tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, false) - assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID) // userinfo must not fail _, err = rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.Subject, provider) @@ -431,7 +431,7 @@ func TestOPStorage_TerminateSession_refresh_grant(t *testing.T) { tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, true) - assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID) // userinfo must not fail _, err = rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.Subject, provider) @@ -475,7 +475,7 @@ func TestOPStorage_TerminateSession_empty_id_token_hint(t *testing.T) { tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, false) - assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID) postLogoutRedirect, err := rp.EndSession(CTX, provider, "", logoutRedirectURI, "state") require.NoError(t, err) @@ -530,8 +530,13 @@ func assertTokens(t *testing.T, tokens *oidc.Tokens[*oidc.IDTokenClaims], requir } } -func assertIDTokenClaims(t *testing.T, claims *oidc.IDTokenClaims, userID string, arm []string, sessionStart, sessionChange time.Time) { +func assertIDTokenClaims(t *testing.T, claims *oidc.IDTokenClaims, userID string, arm []string, sessionStart, sessionChange time.Time, sessionID string) { assert.Equal(t, userID, claims.Subject) assert.Equal(t, arm, claims.AuthenticationMethodsReferences) assertOIDCTimeRange(t, claims.AuthTime, sessionStart, sessionChange) + assert.Equal(t, sessionID, claims.SessionID) + assert.Empty(t, claims.Name) + assert.Empty(t, claims.GivenName) + assert.Empty(t, claims.FamilyName) + assert.Empty(t, claims.PreferredUsername) } diff --git a/internal/api/oidc/client_integration_test.go b/internal/api/oidc/client_integration_test.go index 58fdebef07..21d54a59dc 100644 --- a/internal/api/oidc/client_integration_test.go +++ b/internal/api/oidc/client_integration_test.go @@ -122,7 +122,7 @@ func TestServer_Introspect(t *testing.T) { tokens, err := exchangeTokens(t, app.GetClientId(), code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, true) - assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID) // test actual introspection introspection, err := rs.Introspect[*oidc.IntrospectionResponse](context.Background(), resourceServer, tokens.AccessToken) @@ -317,7 +317,7 @@ func TestServer_VerifyClient(t *testing.T) { } require.NoError(t, err) assertTokens(t, tokens, false) - assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID) }) } } diff --git a/internal/api/oidc/introspect.go b/internal/api/oidc/introspect.go index 41ffc3897f..b0881b6d65 100644 --- a/internal/api/oidc/introspect.go +++ b/internal/api/oidc/introspect.go @@ -11,6 +11,7 @@ import ( "github.com/zitadel/oidc/v3/pkg/op" "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/query" "github.com/zitadel/zitadel/internal/telemetry/tracing" "github.com/zitadel/zitadel/internal/zerrors" @@ -103,7 +104,14 @@ func (s *Server) Introspect(ctx context.Context, r *op.Request[op.IntrospectionR if err = validateIntrospectionAudience(token.audience, client.clientID, client.projectID); err != nil { return nil, err } - userInfo, err := s.userInfo(ctx, token.userID, token.scope, client.projectID, client.projectRoleAssertion, true) + userInfo, err := s.userInfo( + token.userID, + token.scope, + client.projectID, + client.projectRoleAssertion, + true, + true, + )(ctx, true, domain.TriggerTypePreUserinfoCreation) if err != nil { return nil, err } diff --git a/internal/api/oidc/oidc_integration_test.go b/internal/api/oidc/oidc_integration_test.go index c7a101afd3..6df8daa132 100644 --- a/internal/api/oidc/oidc_integration_test.go +++ b/internal/api/oidc/oidc_integration_test.go @@ -77,7 +77,7 @@ func Test_ZITADEL_API_missing_audience_scope(t *testing.T) { tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, false) - assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID) ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken)) @@ -142,7 +142,7 @@ func Test_ZITADEL_API_missing_mfa_2fa_setup(t *testing.T) { code := assertCodeResponse(t, linkResp.GetCallbackUrl()) tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) - assertIDTokenClaims(t, tokens.IDTokenClaims, userResp.GetUserId(), armPassword, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, userResp.GetUserId(), armPassword, startTime, changeTime, sessionID) ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken)) @@ -173,7 +173,7 @@ func Test_ZITADEL_API_missing_mfa_policy(t *testing.T) { code := assertCodeResponse(t, linkResp.GetCallbackUrl()) tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) - assertIDTokenClaims(t, tokens.IDTokenClaims, userID, armPassword, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, userID, armPassword, startTime, changeTime, sessionID) ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken)) @@ -227,7 +227,7 @@ func Test_ZITADEL_API_success(t *testing.T) { tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, false) - assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID) ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken)) @@ -261,7 +261,7 @@ func Test_ZITADEL_API_glob_redirects(t *testing.T) { tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, false) - assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID) ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken)) @@ -290,7 +290,7 @@ func Test_ZITADEL_API_inactive_access_token(t *testing.T) { tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, true) - assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID) // make sure token works ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken)) @@ -332,7 +332,7 @@ func Test_ZITADEL_API_terminated_session(t *testing.T) { tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, true) - assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID) // make sure token works ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken)) @@ -402,7 +402,7 @@ func Test_ZITADEL_API_terminated_session_user_disabled(t *testing.T) { tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, true) - assertIDTokenClaims(t, tokens.IDTokenClaims, disabledUser.GetUserId(), armPassword, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, disabledUser.GetUserId(), armPassword, startTime, changeTime, sessionID) // make sure token works ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken)) diff --git a/internal/api/oidc/token.go b/internal/api/oidc/token.go index c45eb98acb..be3a30ed73 100644 --- a/internal/api/oidc/token.go +++ b/internal/api/oidc/token.go @@ -29,8 +29,8 @@ In some cases step 1 till 3 are completely implemented in the command package, for example the v2 code exchange and refresh token. */ -func (s *Server) accessTokenResponseFromSession(ctx context.Context, client op.Client, session *command.OIDCSession, state, projectID string, projectRoleAssertion bool) (_ *oidc.AccessTokenResponse, err error) { - getUserInfo := s.getUserInfoOnce(session.UserID, projectID, projectRoleAssertion, session.Scope) +func (s *Server) accessTokenResponseFromSession(ctx context.Context, client op.Client, session *command.OIDCSession, state, projectID string, projectRoleAssertion, accessTokenRoleAssertion, idTokenRoleAssertion, userInfoAssertion bool) (_ *oidc.AccessTokenResponse, err error) { + getUserInfo := s.getUserInfo(session.UserID, projectID, projectRoleAssertion, userInfoAssertion, session.Scope) getSigner := s.getSignerOnce() resp := &oidc.AccessTokenResponse{ @@ -43,7 +43,7 @@ func (s *Server) accessTokenResponseFromSession(ctx context.Context, client op.C // If the session does not have a token ID, it is an implicit ID-Token only response. if session.TokenID != "" { if client.AccessTokenType() == op.AccessTokenTypeJWT { - resp.AccessToken, err = s.createJWT(ctx, client, session, getUserInfo, getSigner) + resp.AccessToken, err = s.createJWT(ctx, client, session, getUserInfo, accessTokenRoleAssertion, getSigner) } else { resp.AccessToken, err = op.CreateBearerToken(session.TokenID, session.UserID, s.opCrypto) } @@ -53,7 +53,7 @@ func (s *Server) accessTokenResponseFromSession(ctx context.Context, client op.C } if slices.Contains(session.Scope, oidc.ScopeOpenID) { - resp.IDToken, _, err = s.createIDToken(ctx, client, getUserInfo, getSigner, session.SessionID, resp.AccessToken, session.Audience, session.AuthMethods, session.AuthTime, session.Nonce, session.Actor) + resp.IDToken, _, err = s.createIDToken(ctx, client, getUserInfo, idTokenRoleAssertion, getSigner, session.SessionID, resp.AccessToken, session.Audience, session.AuthMethods, session.AuthTime, session.Nonce, session.Actor) } return resp, err } @@ -92,31 +92,22 @@ func (s *Server) getSignerOnce() signerFunc { } // userInfoFunc is a getter function that allows add-hoc retrieval of a user. -type userInfoFunc func(ctx context.Context) (*oidc.UserInfo, error) +type userInfoFunc func(ctx context.Context, roleAssertion bool, triggerType domain.TriggerType) (*oidc.UserInfo, error) -// getUserInfoOnce returns a function which retrieves userinfo from the database once. -// Repeated calls of the returned function return the same results. -func (s *Server) getUserInfoOnce(userID, projectID string, projectRoleAssertion bool, scope []string) userInfoFunc { - var ( - once sync.Once - userInfo *oidc.UserInfo - err error - ) - return func(ctx context.Context) (*oidc.UserInfo, error) { - once.Do(func() { - ctx, span := tracing.NewSpan(ctx) - defer func() { span.EndWithError(err) }() - userInfo, err = s.userInfo(ctx, userID, scope, projectID, projectRoleAssertion, false) - }) - return userInfo, err +// getUserInfo returns a function which retrieves userinfo from the database once. +// However, each time, role claims are asserted and also action flows will trigger. +func (s *Server) getUserInfo(userID, projectID string, projectRoleAssertion, userInfoAssertion bool, scope []string) userInfoFunc { + userInfo := s.userInfo(userID, scope, projectID, projectRoleAssertion, userInfoAssertion, false) + return func(ctx context.Context, roleAssertion bool, triggerType domain.TriggerType) (*oidc.UserInfo, error) { + return userInfo(ctx, roleAssertion, triggerType) } } -func (*Server) createIDToken(ctx context.Context, client op.Client, getUserInfo userInfoFunc, getSigningKey signerFunc, sessionID, accessToken string, audience []string, authMethods []domain.UserAuthMethodType, authTime time.Time, nonce string, actor *domain.TokenActor) (idToken string, exp uint64, err error) { +func (*Server) createIDToken(ctx context.Context, client op.Client, getUserInfo userInfoFunc, roleAssertion bool, getSigningKey signerFunc, sessionID, accessToken string, audience []string, authMethods []domain.UserAuthMethodType, authTime time.Time, nonce string, actor *domain.TokenActor) (idToken string, exp uint64, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - userInfo, err := getUserInfo(ctx) + userInfo, err := getUserInfo(ctx, roleAssertion, domain.TriggerTypePreUserinfoCreation) if err != nil { return "", 0, err } @@ -156,11 +147,11 @@ func timeToOIDCExpiresIn(exp time.Time) uint64 { return uint64(time.Until(exp) / time.Second) } -func (*Server) createJWT(ctx context.Context, client op.Client, session *command.OIDCSession, getUserInfo userInfoFunc, getSigner signerFunc) (_ string, err error) { +func (s *Server) createJWT(ctx context.Context, client op.Client, session *command.OIDCSession, getUserInfo userInfoFunc, assertRoles bool, getSigner signerFunc) (_ string, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - userInfo, err := getUserInfo(ctx) + userInfo, err := getUserInfo(ctx, assertRoles, domain.TriggerTypePreAccessTokenCreation) if err != nil { return "", err } diff --git a/internal/api/oidc/token_client_credentials.go b/internal/api/oidc/token_client_credentials.go index e0cb29770b..4b3bf20acd 100644 --- a/internal/api/oidc/token_client_credentials.go +++ b/internal/api/oidc/token_client_credentials.go @@ -47,5 +47,5 @@ func (s *Server) ClientCredentialsExchange(ctx context.Context, r *op.ClientRequ false, ) - return response(s.accessTokenResponseFromSession(ctx, client, session, "", "", false)) + return response(s.accessTokenResponseFromSession(ctx, client, session, "", "", false, true, false, false)) } diff --git a/internal/api/oidc/token_client_credentials_integration_test.go b/internal/api/oidc/token_client_credentials_integration_test.go index 3517438596..21a1c4de75 100644 --- a/internal/api/oidc/token_client_credentials_integration_test.go +++ b/internal/api/oidc/token_client_credentials_integration_test.go @@ -4,6 +4,7 @@ package oidc_test import ( "testing" + "time" "github.com/brianvoe/gofakeit/v6" "github.com/stretchr/testify/assert" @@ -18,10 +19,13 @@ import ( ) func TestServer_ClientCredentialsExchange(t *testing.T) { - userID, clientID, clientSecret, err := Tester.CreateOIDCCredentialsClient(CTX) + machine, name, clientID, clientSecret, err := Tester.CreateOIDCCredentialsClient(CTX) require.NoError(t, err) type claims struct { + name string + username string + updated time.Time resourceOwnerID any resourceOwnerName any resourceOwnerPrimaryDomain any @@ -78,6 +82,17 @@ func TestServer_ClientCredentialsExchange(t *testing.T) { clientSecret: clientSecret, scope: []string{oidc.ScopeOpenID}, }, + { + name: "openid, profile, email", + clientID: clientID, + clientSecret: clientSecret, + scope: []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail}, + wantClaims: claims{ + name: name, + username: name, + updated: machine.GetDetails().GetChangeDate().AsTime(), + }, + }, { name: "org id and domain scope", clientID: clientID, @@ -132,12 +147,20 @@ func TestServer_ClientCredentialsExchange(t *testing.T) { } require.NoError(t, err) require.NotNil(t, tokens) - userinfo, err := rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, oidc.BearerToken, userID, provider) + userinfo, err := rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, oidc.BearerToken, machine.GetUserId(), provider) require.NoError(t, err) assert.Equal(t, tt.wantClaims.resourceOwnerID, userinfo.Claims[oidc_api.ClaimResourceOwnerID]) assert.Equal(t, tt.wantClaims.resourceOwnerName, userinfo.Claims[oidc_api.ClaimResourceOwnerName]) assert.Equal(t, tt.wantClaims.resourceOwnerPrimaryDomain, userinfo.Claims[oidc_api.ClaimResourceOwnerPrimaryDomain]) assert.Equal(t, tt.wantClaims.orgDomain, userinfo.Claims[domain.OrgDomainPrimaryClaim]) + assert.Equal(t, tt.wantClaims.name, userinfo.Name) + assert.Equal(t, tt.wantClaims.username, userinfo.PreferredUsername) + assertOIDCTime(t, userinfo.UpdatedAt, tt.wantClaims.updated) + assert.Empty(t, userinfo.UserInfoProfile.FamilyName) + assert.Empty(t, userinfo.UserInfoProfile.GivenName) + assert.Empty(t, userinfo.UserInfoEmail) + assert.Empty(t, userinfo.UserInfoPhone) + assert.Empty(t, userinfo.Address) }) } } diff --git a/internal/api/oidc/token_code.go b/internal/api/oidc/token_code.go index 85aa847579..b7ccd1d22f 100644 --- a/internal/api/oidc/token_code.go +++ b/internal/api/oidc/token_code.go @@ -49,7 +49,7 @@ func (s *Server) CodeExchange(ctx context.Context, r *op.ClientRequest[oidc.Acce if err != nil { return nil, err } - return response(s.accessTokenResponseFromSession(ctx, client, session, state, client.client.ProjectID, client.client.ProjectRoleAssertion)) + return response(s.accessTokenResponseFromSession(ctx, client, session, state, client.client.ProjectID, client.client.ProjectRoleAssertion, client.client.AccessTokenRoleAssertion, client.client.IDTokenRoleAssertion, client.client.IDTokenUserinfoAssertion)) } // codeExchangeV1 creates a v2 token from a v1 auth request. diff --git a/internal/api/oidc/token_device.go b/internal/api/oidc/token_device.go index c70fa9c1e9..b574af1260 100644 --- a/internal/api/oidc/token_device.go +++ b/internal/api/oidc/token_device.go @@ -26,7 +26,7 @@ func (s *Server) DeviceToken(ctx context.Context, r *op.ClientRequest[oidc.Devic } session, err := s.command.CreateOIDCSessionFromDeviceAuth(ctx, r.Data.DeviceCode) if err == nil { - return response(s.accessTokenResponseFromSession(ctx, client, session, "", client.client.ProjectID, client.client.ProjectRoleAssertion)) + return response(s.accessTokenResponseFromSession(ctx, client, session, "", client.client.ProjectID, client.client.ProjectRoleAssertion, client.client.AccessTokenRoleAssertion, client.client.IDTokenRoleAssertion, client.client.IDTokenUserinfoAssertion)) } if errors.Is(err, context.DeadlineExceeded) { return nil, oidc.ErrSlowDown().WithParent(err) diff --git a/internal/api/oidc/token_exchange.go b/internal/api/oidc/token_exchange.go index 31b1a37db3..bd19e565cc 100644 --- a/internal/api/oidc/token_exchange.go +++ b/internal/api/oidc/token_exchange.go @@ -218,7 +218,7 @@ func validateTokenExchangeAudience(requestedAudience, subjectAudience, actorAudi // Both tokens may point to the same object (subjectToken) in case of a regular Token Exchange. // When the subject and actor Tokens point to different objects, the new tokens will be for impersonation / delegation. func (s *Server) createExchangeTokens(ctx context.Context, tokenType oidc.TokenType, client *Client, subjectToken, actorToken *exchangeToken, audience, scopes []string) (_ *oidc.TokenExchangeResponse, err error) { - getUserInfo := s.getUserInfoOnce(subjectToken.userID, client.client.ProjectID, client.client.ProjectRoleAssertion, scopes) + getUserInfo := s.getUserInfo(subjectToken.userID, client.client.ProjectID, client.client.ProjectRoleAssertion, client.IDTokenUserinfoClaimsAssertion(), scopes) getSigner := s.getSignerOnce() resp := &oidc.TokenExchangeResponse{ @@ -240,12 +240,12 @@ func (s *Server) createExchangeTokens(ctx context.Context, tokenType oidc.TokenT resp.IssuedTokenType = oidc.AccessTokenType case oidc.JWTTokenType: - resp.AccessToken, resp.RefreshToken, resp.ExpiresIn, err = s.createExchangeJWT(ctx, client, getUserInfo, getSigner, subjectToken.userID, subjectToken.resourceOwner, audience, scopes, actorToken.authMethods, actorToken.authTime, subjectToken.preferredLanguage, reason, actor) + resp.AccessToken, resp.RefreshToken, resp.ExpiresIn, err = s.createExchangeJWT(ctx, client, getUserInfo, client.client.AccessTokenRoleAssertion, getSigner, subjectToken.userID, subjectToken.resourceOwner, audience, scopes, actorToken.authMethods, actorToken.authTime, subjectToken.preferredLanguage, reason, actor) resp.TokenType = oidc.BearerToken resp.IssuedTokenType = oidc.JWTTokenType case oidc.IDTokenType: - resp.AccessToken, resp.ExpiresIn, err = s.createIDToken(ctx, client, getUserInfo, getSigner, "", resp.AccessToken, audience, actorToken.authMethods, actorToken.authTime, "", actor) + resp.AccessToken, resp.ExpiresIn, err = s.createIDToken(ctx, client, getUserInfo, client.client.IDTokenRoleAssertion, getSigner, "", resp.AccessToken, audience, actorToken.authMethods, actorToken.authTime, "", actor) resp.TokenType = TokenTypeNA resp.IssuedTokenType = oidc.IDTokenType @@ -259,7 +259,7 @@ func (s *Server) createExchangeTokens(ctx context.Context, tokenType oidc.TokenT } if slices.Contains(scopes, oidc.ScopeOpenID) && tokenType != oidc.IDTokenType { - resp.IDToken, _, err = s.createIDToken(ctx, client, getUserInfo, getSigner, sessionID, resp.AccessToken, audience, actorToken.authMethods, actorToken.authTime, "", actor) + resp.IDToken, _, err = s.createIDToken(ctx, client, getUserInfo, client.client.IDTokenRoleAssertion, getSigner, sessionID, resp.AccessToken, audience, actorToken.authMethods, actorToken.authTime, "", actor) if err != nil { return nil, err } @@ -313,6 +313,7 @@ func (s *Server) createExchangeJWT( ctx context.Context, client *Client, getUserInfo userInfoFunc, + roleAssertion bool, getSigner signerFunc, userID, resourceOwner string, @@ -342,7 +343,7 @@ func (s *Server) createExchangeJWT( actor, slices.Contains(scope, oidc.ScopeOfflineAccess), ) - accessToken, err = s.createJWT(ctx, client, session, getUserInfo, getSigner) + accessToken, err = s.createJWT(ctx, client, session, getUserInfo, roleAssertion, getSigner) if err != nil { return "", "", 0, err } diff --git a/internal/api/oidc/token_jwt_profile.go b/internal/api/oidc/token_jwt_profile.go index b23a24f77f..fc0c31a6eb 100644 --- a/internal/api/oidc/token_jwt_profile.go +++ b/internal/api/oidc/token_jwt_profile.go @@ -54,7 +54,7 @@ func (s *Server) JWTProfile(ctx context.Context, r *op.Request[oidc.JWTProfileGr nil, false, ) - return response(s.accessTokenResponseFromSession(ctx, client, session, "", "", false)) + return response(s.accessTokenResponseFromSession(ctx, client, session, "", "", false, true, false, false)) } func (s *Server) verifyJWTProfile(ctx context.Context, req *oidc.JWTProfileGrantRequest) (user *query.User, tokenRequest *oidc.JWTTokenRequest, err error) { diff --git a/internal/api/oidc/token_jwt_profile_integration_test.go b/internal/api/oidc/token_jwt_profile_integration_test.go index b80ea09bcd..0ad8d76da2 100644 --- a/internal/api/oidc/token_jwt_profile_integration_test.go +++ b/internal/api/oidc/token_jwt_profile_integration_test.go @@ -4,6 +4,7 @@ package oidc_test import ( "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -16,10 +17,13 @@ import ( ) func TestServer_JWTProfile(t *testing.T) { - userID, keyData, err := Tester.CreateOIDCJWTProfileClient(CTX) + user, name, keyData, err := Tester.CreateOIDCJWTProfileClient(CTX) require.NoError(t, err) type claims struct { + name string + username string + updated time.Time resourceOwnerID any resourceOwnerName any resourceOwnerPrimaryDomain any @@ -37,6 +41,16 @@ func TestServer_JWTProfile(t *testing.T) { keyData: keyData, scope: []string{oidc.ScopeOpenID}, }, + { + name: "openid, profile, email", + keyData: keyData, + scope: []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail}, + wantClaims: claims{ + name: name, + username: name, + updated: user.GetDetails().GetChangeDate().AsTime(), + }, + }, { name: "org id and domain scope", keyData: keyData, @@ -92,12 +106,20 @@ func TestServer_JWTProfile(t *testing.T) { provider, err := rp.NewRelyingPartyOIDC(CTX, Tester.OIDCIssuer(), "", "", redirectURI, tt.scope) require.NoError(t, err) - userinfo, err := rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, oidc.BearerToken, userID, provider) + userinfo, err := rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, oidc.BearerToken, user.GetUserId(), provider) require.NoError(t, err) assert.Equal(t, tt.wantClaims.resourceOwnerID, userinfo.Claims[oidc_api.ClaimResourceOwnerID]) assert.Equal(t, tt.wantClaims.resourceOwnerName, userinfo.Claims[oidc_api.ClaimResourceOwnerName]) assert.Equal(t, tt.wantClaims.resourceOwnerPrimaryDomain, userinfo.Claims[oidc_api.ClaimResourceOwnerPrimaryDomain]) assert.Equal(t, tt.wantClaims.orgDomain, userinfo.Claims[domain.OrgDomainPrimaryClaim]) + assert.Equal(t, tt.wantClaims.name, userinfo.Name) + assert.Equal(t, tt.wantClaims.username, userinfo.PreferredUsername) + assertOIDCTime(t, userinfo.UpdatedAt, tt.wantClaims.updated) + assert.Empty(t, userinfo.UserInfoProfile.FamilyName) + assert.Empty(t, userinfo.UserInfoProfile.GivenName) + assert.Empty(t, userinfo.UserInfoEmail) + assert.Empty(t, userinfo.UserInfoPhone) + assert.Empty(t, userinfo.Address) }) } } diff --git a/internal/api/oidc/token_refresh.go b/internal/api/oidc/token_refresh.go index 66a8b8a263..1dcce2879a 100644 --- a/internal/api/oidc/token_refresh.go +++ b/internal/api/oidc/token_refresh.go @@ -28,7 +28,7 @@ func (s *Server) RefreshToken(ctx context.Context, r *op.ClientRequest[oidc.Refr session, err := s.command.ExchangeOIDCSessionRefreshAndAccessToken(ctx, r.Data.RefreshToken, r.Data.Scopes, refreshTokenComplianceChecker()) if err == nil { - return response(s.accessTokenResponseFromSession(ctx, client, session, "", client.client.ProjectID, client.client.ProjectRoleAssertion)) + return response(s.accessTokenResponseFromSession(ctx, client, session, "", client.client.ProjectID, client.client.ProjectRoleAssertion, client.client.AccessTokenRoleAssertion, client.client.IDTokenRoleAssertion, client.client.IDTokenUserinfoAssertion)) } else if errors.Is(err, zerrors.ThrowPreconditionFailed(nil, "OIDCS-JOI23", "Errors.OIDCSession.RefreshTokenInvalid")) { // We try again for v1 tokens when we encountered specific parsing error return s.refreshTokenV1(ctx, client, r) @@ -78,7 +78,7 @@ func (s *Server) refreshTokenV1(ctx context.Context, client *Client, r *op.Clien return nil, err } - return response(s.accessTokenResponseFromSession(ctx, client, session, "", client.client.ProjectID, client.client.ProjectRoleAssertion)) + return response(s.accessTokenResponseFromSession(ctx, client, session, "", client.client.ProjectID, client.client.ProjectRoleAssertion, client.client.AccessTokenRoleAssertion, client.client.IDTokenRoleAssertion, client.client.IDTokenUserinfoAssertion)) } // refreshTokenComplianceChecker validates that the requested scope is a subset of the original auth request scope. diff --git a/internal/api/oidc/userinfo.go b/internal/api/oidc/userinfo.go index c21b746c49..7d352457bc 100644 --- a/internal/api/oidc/userinfo.go +++ b/internal/api/oidc/userinfo.go @@ -8,6 +8,7 @@ import ( "net/http" "slices" "strings" + "sync" "github.com/dop251/goja" "github.com/zitadel/logging" @@ -55,7 +56,14 @@ func (s *Server) UserInfo(ctx context.Context, r *op.Request[oidc.UserInfoReques } } - userInfo, err := s.userInfo(ctx, token.userID, token.scope, projectID, assertion, false) + userInfo, err := s.userInfo( + token.userID, + token.scope, + projectID, + assertion, + true, + false, + )(ctx, true, domain.TriggerTypePreUserinfoCreation) if err != nil { return nil, err } @@ -66,24 +74,44 @@ func (s *Server) UserInfo(ctx context.Context, r *op.Request[oidc.UserInfoReques // The returned UserInfo contains standard and reserved claims, documented // here: https://zitadel.com/docs/apis/openidoauth/claims. // +// User information is only retrieved once from the database. +// However, each time, role claims are asserted and also action flows will trigger. +// // projectID is an optional parameter which defines the default audience when there are any (or all) role claims requested. // projectRoleAssertion sets the default of returning all project roles, only if no specific roles were requested in the scope. +// roleAssertion decides whether the roles will be returned (in the token or response) +// userInfoAssertion decides whether the user information (profile data like name, email, ...) are returned // // currentProjectOnly can be set to use the current project ID only and ignore the audience from the scope. // It should be set in cases where the client doesn't need to know roles outside its own project, // for example an introspection client. -func (s *Server) userInfo(ctx context.Context, userID string, scope []string, projectID string, projectRoleAssertion, currentProjectOnly bool) (_ *oidc.UserInfo, err error) { - ctx, span := tracing.NewSpan(ctx) - defer func() { span.EndWithError(err) }() +func (s *Server) userInfo( + userID string, + scope []string, + projectID string, + projectRoleAssertion, userInfoAssertion, currentProjectOnly bool, +) func(ctx context.Context, roleAssertion bool, triggerType domain.TriggerType) (_ *oidc.UserInfo, err error) { + var ( + once sync.Once + userInfo *oidc.UserInfo + qu *query.OIDCUserInfo + roleAudience, requestedRoles []string + ) + return func(ctx context.Context, roleAssertion bool, triggerType domain.TriggerType) (_ *oidc.UserInfo, err error) { + once.Do(func() { + ctx, span := tracing.NewSpan(ctx) + defer func() { span.EndWithError(err) }() - roleAudience, requestedRoles := prepareRoles(ctx, scope, projectID, projectRoleAssertion, currentProjectOnly) - qu, err := s.query.GetOIDCUserInfo(ctx, userID, roleAudience) - if err != nil { - return nil, err + roleAudience, requestedRoles = prepareRoles(ctx, scope, projectID, projectRoleAssertion, currentProjectOnly) + qu, err = s.query.GetOIDCUserInfo(ctx, userID, roleAudience) + if err != nil { + return + } + userInfo = userInfoToOIDC(qu, userInfoAssertion, scope, s.assetAPIPrefix(ctx)) + }) + userInfoWithRoles := assertRoles(projectID, qu, roleAudience, requestedRoles, roleAssertion, userInfo) + return userInfoWithRoles, s.userinfoFlows(ctx, qu, userInfoWithRoles, triggerType) } - - userInfo := userInfoToOIDC(projectID, qu, scope, roleAudience, requestedRoles, s.assetAPIPrefix(ctx)) - return userInfo, s.userinfoFlows(ctx, qu, userInfo) } // prepareRoles scans the requested scopes and builds the requested roles @@ -120,20 +148,32 @@ func prepareRoles(ctx context.Context, scope []string, projectID string, project return roleAudience, requestedRoles } -func userInfoToOIDC(projectID string, user *query.OIDCUserInfo, scope, roleAudience, requestedRoles []string, assetPrefix string) *oidc.UserInfo { +func userInfoToOIDC(user *query.OIDCUserInfo, userInfoAssertion bool, scope []string, assetPrefix string) *oidc.UserInfo { out := new(oidc.UserInfo) for _, s := range scope { switch s { case oidc.ScopeOpenID: out.Subject = user.User.ID case oidc.ScopeEmail: + if !userInfoAssertion { + continue + } out.UserInfoEmail = userInfoEmailToOIDC(user.User) case oidc.ScopeProfile: + if !userInfoAssertion { + continue + } out.UserInfoProfile = userInfoProfileToOidc(user.User, assetPrefix) case oidc.ScopePhone: + if !userInfoAssertion { + continue + } out.UserInfoPhone = userInfoPhoneToOIDC(user.User) case oidc.ScopeAddress: - //TODO: handle address for human users as soon as implemented + if !userInfoAssertion { + continue + } + // TODO: handle address for human users as soon as implemented case ScopeUserMetaData: setUserInfoMetadata(user.Metadata, out) case ScopeResourceOwner: @@ -148,12 +188,19 @@ func userInfoToOIDC(projectID string, user *query.OIDCUserInfo, scope, roleAudie } } } + return out +} +func assertRoles(projectID string, user *query.OIDCUserInfo, roleAudience, requestedRoles []string, assertion bool, info *oidc.UserInfo) *oidc.UserInfo { + if !assertion { + return info + } + userInfo := *info // prevent returning obtained grants if none where requested if (projectID != "" && len(requestedRoles) > 0) || len(roleAudience) > 0 { - setUserInfoRoleClaims(out, newProjectRoles(projectID, user.UserGrants, requestedRoles)) + setUserInfoRoleClaims(&userInfo, newProjectRoles(projectID, user.UserGrants, requestedRoles)) } - return out + return &userInfo } func userInfoEmailToOIDC(user *query.User) oidc.UserInfoEmail { @@ -230,11 +277,11 @@ func setUserInfoRoleClaims(userInfo *oidc.UserInfo, roles *projectsRoles) { } } -func (s *Server) userinfoFlows(ctx context.Context, qu *query.OIDCUserInfo, userInfo *oidc.UserInfo) (err error) { +func (s *Server) userinfoFlows(ctx context.Context, qu *query.OIDCUserInfo, userInfo *oidc.UserInfo, triggerType domain.TriggerType) (err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - queriedActions, err := s.query.GetActiveActionsByFlowAndTriggerType(ctx, domain.FlowTypeCustomiseToken, domain.TriggerTypePreUserinfoCreation, qu.User.ResourceOwner) + queriedActions, err := s.query.GetActiveActionsByFlowAndTriggerType(ctx, domain.FlowTypeCustomiseToken, triggerType, qu.User.ResourceOwner) if err != nil { return err } diff --git a/internal/api/oidc/userinfo_integration_test.go b/internal/api/oidc/userinfo_integration_test.go index 22e688ff4b..7f39ed38ba 100644 --- a/internal/api/oidc/userinfo_integration_test.go +++ b/internal/api/oidc/userinfo_integration_test.go @@ -231,9 +231,9 @@ func TestServer_UserInfo_Issue6662(t *testing.T) { project, err := Tester.CreateProject(CTX) projectID := project.GetId() require.NoError(t, err) - userID, clientID, clientSecret, err := Tester.CreateOIDCCredentialsClient(CTX) + user, _, clientID, clientSecret, err := Tester.CreateOIDCCredentialsClient(CTX) require.NoError(t, err) - addProjectRolesGrants(t, userID, projectID, roleFoo, roleBar) + addProjectRolesGrants(t, user.GetUserId(), projectID, roleFoo, roleBar) scope := []string{oidc.ScopeProfile, oidc.ScopeOpenID, oidc.ScopeEmail, oidc.ScopeOfflineAccess, oidc_api.ScopeProjectRolePrefix + roleFoo, @@ -245,7 +245,7 @@ func TestServer_UserInfo_Issue6662(t *testing.T) { tokens, err := rp.ClientCredentials(CTX, provider, nil) require.NoError(t, err) - userinfo, err := rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, tokens.TokenType, userID, provider) + userinfo, err := rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, tokens.TokenType, user.GetUserId(), provider) require.NoError(t, err) assertProjectRoleClaims(t, projectID, userinfo.Claims, false, roleFoo) } @@ -291,7 +291,7 @@ func getTokens(t *testing.T, clientID string, scope []string) *oidc.Tokens[*oidc tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, true) - assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID) return tokens } diff --git a/internal/api/oidc/userinfo_test.go b/internal/api/oidc/userinfo_test.go index 65354a4040..741f7eed36 100644 --- a/internal/api/oidc/userinfo_test.go +++ b/internal/api/oidc/userinfo_test.go @@ -3,7 +3,6 @@ package oidc import ( "context" "encoding/base64" - "fmt" "testing" "time" @@ -267,11 +266,9 @@ func Test_userInfoToOIDC(t *testing.T) { } type args struct { - projectID string - user *query.OIDCUserInfo - scope []string - roleAudience []string - requestedRoles []string + user *query.OIDCUserInfo + userInfoAssertion bool + scope []string } tests := []struct { name string @@ -281,25 +278,22 @@ func Test_userInfoToOIDC(t *testing.T) { { name: "human, empty", args: args{ - projectID: "project1", - user: humanUserInfo, + user: humanUserInfo, }, want: &oidc.UserInfo{}, }, { name: "machine, empty", args: args{ - projectID: "project1", - user: machineUserInfo, + user: machineUserInfo, }, want: &oidc.UserInfo{}, }, { name: "human, scope openid", args: args{ - projectID: "project1", - user: humanUserInfo, - scope: []string{oidc.ScopeOpenID}, + user: humanUserInfo, + scope: []string{oidc.ScopeOpenID}, }, want: &oidc.UserInfo{ Subject: "human1", @@ -308,20 +302,19 @@ func Test_userInfoToOIDC(t *testing.T) { { name: "machine, scope openid", args: args{ - projectID: "project1", - user: machineUserInfo, - scope: []string{oidc.ScopeOpenID}, + user: machineUserInfo, + scope: []string{oidc.ScopeOpenID}, }, want: &oidc.UserInfo{ Subject: "machine1", }, }, { - name: "human, scope email", + name: "human, scope email, profileInfoAssertion", args: args{ - projectID: "project1", - user: humanUserInfo, - scope: []string{oidc.ScopeEmail}, + user: humanUserInfo, + userInfoAssertion: true, + scope: []string{oidc.ScopeEmail}, }, want: &oidc.UserInfo{ UserInfoEmail: oidc.UserInfoEmail{ @@ -331,22 +324,29 @@ func Test_userInfoToOIDC(t *testing.T) { }, }, { - name: "machine, scope email", + name: "human, scope email", args: args{ - projectID: "project1", - user: machineUserInfo, - scope: []string{oidc.ScopeEmail}, + user: humanUserInfo, + scope: []string{oidc.ScopeEmail}, + }, + want: &oidc.UserInfo{}, + }, + { + name: "machine, scope email, profileInfoAssertion", + args: args{ + user: machineUserInfo, + scope: []string{oidc.ScopeEmail}, }, want: &oidc.UserInfo{ UserInfoEmail: oidc.UserInfoEmail{}, }, }, { - name: "human, scope profile", + name: "human, scope profile, profileInfoAssertion", args: args{ - projectID: "project1", - user: humanUserInfo, - scope: []string{oidc.ScopeProfile}, + user: humanUserInfo, + userInfoAssertion: true, + scope: []string{oidc.ScopeProfile}, }, want: &oidc.UserInfo{ UserInfoProfile: oidc.UserInfoProfile{ @@ -363,11 +363,11 @@ func Test_userInfoToOIDC(t *testing.T) { }, }, { - name: "machine, scope profile", + name: "machine, scope profile, profileInfoAssertion", args: args{ - projectID: "project1", - user: machineUserInfo, - scope: []string{oidc.ScopeProfile}, + user: machineUserInfo, + userInfoAssertion: true, + scope: []string{oidc.ScopeProfile}, }, want: &oidc.UserInfo{ UserInfoProfile: oidc.UserInfoProfile{ @@ -378,11 +378,19 @@ func Test_userInfoToOIDC(t *testing.T) { }, }, { - name: "human, scope phone", + name: "machine, scope profile", args: args{ - projectID: "project1", - user: humanUserInfo, - scope: []string{oidc.ScopePhone}, + user: machineUserInfo, + scope: []string{oidc.ScopeProfile}, + }, + want: &oidc.UserInfo{}, + }, + { + name: "human, scope phone, profileInfoAssertion", + args: args{ + user: humanUserInfo, + userInfoAssertion: true, + scope: []string{oidc.ScopePhone}, }, want: &oidc.UserInfo{ UserInfoPhone: oidc.UserInfoPhone{ @@ -391,12 +399,19 @@ func Test_userInfoToOIDC(t *testing.T) { }, }, }, + { + name: "human, scope phone", + args: args{ + user: humanUserInfo, + scope: []string{oidc.ScopePhone}, + }, + want: &oidc.UserInfo{}, + }, { name: "machine, scope phone", args: args{ - projectID: "project1", - user: machineUserInfo, - scope: []string{oidc.ScopePhone}, + user: machineUserInfo, + scope: []string{oidc.ScopePhone}, }, want: &oidc.UserInfo{ UserInfoPhone: oidc.UserInfoPhone{}, @@ -405,9 +420,8 @@ func Test_userInfoToOIDC(t *testing.T) { { name: "human, scope metadata", args: args{ - projectID: "project1", - user: humanUserInfo, - scope: []string{ScopeUserMetaData}, + user: humanUserInfo, + scope: []string{ScopeUserMetaData}, }, want: &oidc.UserInfo{ Claims: map[string]any{ @@ -421,18 +435,16 @@ func Test_userInfoToOIDC(t *testing.T) { { name: "machine, scope metadata, none found", args: args{ - projectID: "project1", - user: machineUserInfo, - scope: []string{ScopeUserMetaData}, + user: machineUserInfo, + scope: []string{ScopeUserMetaData}, }, want: &oidc.UserInfo{}, }, { name: "machine, scope resource owner", args: args{ - projectID: "project1", - user: machineUserInfo, - scope: []string{ScopeResourceOwner}, + user: machineUserInfo, + scope: []string{ScopeResourceOwner}, }, want: &oidc.UserInfo{ Claims: map[string]any{ @@ -445,9 +457,8 @@ func Test_userInfoToOIDC(t *testing.T) { { name: "human, scope org primary domain prefix", args: args{ - projectID: "project1", - user: humanUserInfo, - scope: []string{domain.OrgDomainPrimaryScope + "foo.com"}, + user: humanUserInfo, + scope: []string{domain.OrgDomainPrimaryScope + "foo.com"}, }, want: &oidc.UserInfo{ Claims: map[string]any{ @@ -458,9 +469,8 @@ func Test_userInfoToOIDC(t *testing.T) { { name: "machine, scope org id", args: args{ - projectID: "project1", - user: machineUserInfo, - scope: []string{domain.OrgIDScope + "orgID"}, + user: machineUserInfo, + scope: []string{domain.OrgIDScope + "orgID"}, }, want: &oidc.UserInfo{ Claims: map[string]any{ @@ -471,50 +481,11 @@ func Test_userInfoToOIDC(t *testing.T) { }, }, }, - { - name: "human, roleAudience", - args: args{ - projectID: "project1", - user: humanUserInfo, - roleAudience: []string{"project1"}, - }, - want: &oidc.UserInfo{ - Claims: map[string]any{ - ClaimProjectRoles: projectRoles{ - "role1": {"orgID": "orgDomain"}, - "role2": {"orgID": "orgDomain"}, - }, - fmt.Sprintf(ClaimProjectRolesFormat, "project1"): projectRoles{ - "role1": {"orgID": "orgDomain"}, - "role2": {"orgID": "orgDomain"}, - }, - }, - }, - }, - { - name: "human, requested roles", - args: args{ - projectID: "project1", - user: humanUserInfo, - roleAudience: []string{"project1"}, - requestedRoles: []string{"role2"}, - }, - want: &oidc.UserInfo{ - Claims: map[string]any{ - ClaimProjectRoles: projectRoles{ - "role2": {"orgID": "orgDomain"}, - }, - fmt.Sprintf(ClaimProjectRolesFormat, "project1"): projectRoles{ - "role2": {"orgID": "orgDomain"}, - }, - }, - }, - }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { assetPrefix := "https://foo.com/assets" - got := userInfoToOIDC(tt.args.projectID, tt.args.user, tt.args.scope, tt.args.roleAudience, tt.args.requestedRoles, assetPrefix) + got := userInfoToOIDC(tt.args.user, tt.args.userInfoAssertion, tt.args.scope, assetPrefix) assert.Equal(t, tt.want, got) }) } diff --git a/internal/integration/oidc.go b/internal/integration/oidc.go index 3ba655c65b..2d7d6a105e 100644 --- a/internal/integration/oidc.go +++ b/internal/integration/oidc.go @@ -281,42 +281,42 @@ func CheckRedirect(req *http.Request) (*url.URL, error) { return resp.Location() } -func (s *Tester) CreateOIDCCredentialsClient(ctx context.Context) (userID, clientID, clientSecret string, err error) { - name := gofakeit.Username() - user, err := s.Client.Mgmt.AddMachineUser(ctx, &management.AddMachineUserRequest{ +func (s *Tester) CreateOIDCCredentialsClient(ctx context.Context) (machine *management.AddMachineUserResponse, name, clientID, clientSecret string, err error) { + name = gofakeit.Username() + machine, err = s.Client.Mgmt.AddMachineUser(ctx, &management.AddMachineUserRequest{ Name: name, UserName: name, AccessTokenType: user.AccessTokenType_ACCESS_TOKEN_TYPE_JWT, }) if err != nil { - return "", "", "", err + return nil, "", "", "", err } secret, err := s.Client.Mgmt.GenerateMachineSecret(ctx, &management.GenerateMachineSecretRequest{ - UserId: user.GetUserId(), + UserId: machine.GetUserId(), }) if err != nil { - return "", "", "", err + return nil, "", "", "", err } - return user.GetUserId(), secret.GetClientId(), secret.GetClientSecret(), nil + return machine, name, secret.GetClientId(), secret.GetClientSecret(), nil } -func (s *Tester) CreateOIDCJWTProfileClient(ctx context.Context) (userID string, keyData []byte, err error) { - name := gofakeit.Username() - user, err := s.Client.Mgmt.AddMachineUser(ctx, &management.AddMachineUserRequest{ +func (s *Tester) CreateOIDCJWTProfileClient(ctx context.Context) (machine *management.AddMachineUserResponse, name string, keyData []byte, err error) { + name = gofakeit.Username() + machine, err = s.Client.Mgmt.AddMachineUser(ctx, &management.AddMachineUserRequest{ Name: name, UserName: name, AccessTokenType: user.AccessTokenType_ACCESS_TOKEN_TYPE_JWT, }) if err != nil { - return "", nil, err + return nil, "", nil, err } keyResp, err := s.Client.Mgmt.AddMachineKey(ctx, &management.AddMachineKeyRequest{ - UserId: user.GetUserId(), + UserId: machine.GetUserId(), Type: authn.KeyType_KEY_TYPE_JSON, ExpirationDate: timestamppb.New(time.Now().Add(time.Hour)), }) if err != nil { - return "", nil, err + return nil, "", nil, err } - return user.GetUserId(), keyResp.GetKeyDetails(), nil + return machine, name, keyResp.GetKeyDetails(), nil }