diff --git a/internal/api/oidc/auth_request.go b/internal/api/oidc/auth_request.go index 1abe59dc88..5d23a8bd98 100644 --- a/internal/api/oidc/auth_request.go +++ b/internal/api/oidc/auth_request.go @@ -471,7 +471,7 @@ func (s *Server) CreateTokenCallbackURL(ctx context.Context, req op.AuthRequest) if err != nil { return "", err } - resp, err := s.accessTokenResponseFromSession(ctx, client, session, state, client.client.ProjectID, client.client.ProjectRoleAssertion) + resp, err := s.accessTokenResponseFromSession(ctx, client, session, state, client.client.ProjectID, client.client.ProjectRoleAssertion, client.client.AccessTokenRoleAssertion, client.client.IDTokenRoleAssertion, client.client.IDTokenUserinfoAssertion) if err != nil { return "", err } @@ -563,7 +563,7 @@ func (s *Server) authResponseToken(authReq *AuthRequest, authorizer op.Authorize op.AuthRequestError(w, r, authReq, err, authorizer) return err } - resp, err := s.accessTokenResponseFromSession(ctx, client, session, authReq.GetState(), client.client.ProjectID, client.client.ProjectRoleAssertion) + resp, err := s.accessTokenResponseFromSession(ctx, client, session, authReq.GetState(), client.client.ProjectID, client.client.ProjectRoleAssertion, client.client.AccessTokenRoleAssertion, client.client.IDTokenRoleAssertion, client.client.IDTokenUserinfoAssertion) if err != nil { op.AuthRequestError(w, r, authReq, err, authorizer) return err diff --git a/internal/api/oidc/auth_request_integration_test.go b/internal/api/oidc/auth_request_integration_test.go index c36e06c6aa..e2e6ae2f7b 100644 --- a/internal/api/oidc/auth_request_integration_test.go +++ b/internal/api/oidc/auth_request_integration_test.go @@ -54,7 +54,7 @@ func TestOPStorage_CreateAccessToken_code(t *testing.T) { tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, false) - assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID) // callback on a succeeded request must fail linkResp, err = Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{ @@ -108,7 +108,7 @@ func TestOPStorage_CreateAccessToken_implicit(t *testing.T) { require.NoError(t, err) claims, err := rp.VerifyTokens[*oidc.IDTokenClaims](context.Background(), accessToken, idToken, provider.IDTokenVerifier()) require.NoError(t, err) - assertIDTokenClaims(t, claims, User.GetUserId(), armPasskey, startTime, changeTime) + assertIDTokenClaims(t, claims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID) // callback on a succeeded request must fail linkResp, err = Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{ @@ -143,7 +143,7 @@ func TestOPStorage_CreateAccessAndRefreshTokens_code(t *testing.T) { tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, true) - assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID) } func TestOPStorage_CreateAccessAndRefreshTokens_refresh(t *testing.T) { @@ -168,14 +168,14 @@ func TestOPStorage_CreateAccessAndRefreshTokens_refresh(t *testing.T) { tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, true) - assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID) // test actual refresh grant newTokens, err := refreshTokens(t, clientID, tokens.RefreshToken) require.NoError(t, err) assertTokens(t, newTokens, true) // auth time must still be the initial - assertIDTokenClaims(t, newTokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) + assertIDTokenClaims(t, newTokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID) // refresh with an old refresh_token must fail _, err = rp.RefreshTokens[*oidc.IDTokenClaims](CTX, provider, tokens.RefreshToken, "", "") @@ -204,7 +204,7 @@ func TestOPStorage_RevokeToken_access_token(t *testing.T) { tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, true) - assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID) // revoke access token err = rp.RevokeToken(CTX, provider, tokens.AccessToken, "access_token") @@ -247,7 +247,7 @@ func TestOPStorage_RevokeToken_access_token_invalid_token_hint_type(t *testing.T tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, true) - assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID) // revoke access token err = rp.RevokeToken(CTX, provider, tokens.AccessToken, "refresh_token") @@ -284,7 +284,7 @@ func TestOPStorage_RevokeToken_refresh_token(t *testing.T) { tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, true) - assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID) // revoke refresh token -> invalidates also access token err = rp.RevokeToken(CTX, provider, tokens.RefreshToken, "refresh_token") @@ -327,7 +327,7 @@ func TestOPStorage_RevokeToken_refresh_token_invalid_token_type_hint(t *testing. tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, true) - assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID) // revoke refresh token even with a wrong hint err = rp.RevokeToken(CTX, provider, tokens.RefreshToken, "access_token") @@ -362,7 +362,7 @@ func TestOPStorage_RevokeToken_invalid_client(t *testing.T) { tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, true) - assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID) // simulate second client (not part of the audience) trying to revoke the token otherClientID, _ := createClient(t) @@ -394,7 +394,7 @@ func TestOPStorage_TerminateSession(t *testing.T) { tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, false) - assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID) // userinfo must not fail _, err = rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.Subject, provider) @@ -431,7 +431,7 @@ func TestOPStorage_TerminateSession_refresh_grant(t *testing.T) { tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, true) - assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID) // userinfo must not fail _, err = rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.Subject, provider) @@ -475,7 +475,7 @@ func TestOPStorage_TerminateSession_empty_id_token_hint(t *testing.T) { tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, false) - assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID) postLogoutRedirect, err := rp.EndSession(CTX, provider, "", logoutRedirectURI, "state") require.NoError(t, err) @@ -530,8 +530,13 @@ func assertTokens(t *testing.T, tokens *oidc.Tokens[*oidc.IDTokenClaims], requir } } -func assertIDTokenClaims(t *testing.T, claims *oidc.IDTokenClaims, userID string, arm []string, sessionStart, sessionChange time.Time) { +func assertIDTokenClaims(t *testing.T, claims *oidc.IDTokenClaims, userID string, arm []string, sessionStart, sessionChange time.Time, sessionID string) { assert.Equal(t, userID, claims.Subject) assert.Equal(t, arm, claims.AuthenticationMethodsReferences) assertOIDCTimeRange(t, claims.AuthTime, sessionStart, sessionChange) + assert.Equal(t, sessionID, claims.SessionID) + assert.Empty(t, claims.Name) + assert.Empty(t, claims.GivenName) + assert.Empty(t, claims.FamilyName) + assert.Empty(t, claims.PreferredUsername) } diff --git a/internal/api/oidc/client_integration_test.go b/internal/api/oidc/client_integration_test.go index 58fdebef07..21d54a59dc 100644 --- a/internal/api/oidc/client_integration_test.go +++ b/internal/api/oidc/client_integration_test.go @@ -122,7 +122,7 @@ func TestServer_Introspect(t *testing.T) { tokens, err := exchangeTokens(t, app.GetClientId(), code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, true) - assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID) // test actual introspection introspection, err := rs.Introspect[*oidc.IntrospectionResponse](context.Background(), resourceServer, tokens.AccessToken) @@ -317,7 +317,7 @@ func TestServer_VerifyClient(t *testing.T) { } require.NoError(t, err) assertTokens(t, tokens, false) - assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID) }) } } diff --git a/internal/api/oidc/introspect.go b/internal/api/oidc/introspect.go index 41ffc3897f..b0881b6d65 100644 --- a/internal/api/oidc/introspect.go +++ b/internal/api/oidc/introspect.go @@ -11,6 +11,7 @@ import ( "github.com/zitadel/oidc/v3/pkg/op" "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/query" "github.com/zitadel/zitadel/internal/telemetry/tracing" "github.com/zitadel/zitadel/internal/zerrors" @@ -103,7 +104,14 @@ func (s *Server) Introspect(ctx context.Context, r *op.Request[op.IntrospectionR if err = validateIntrospectionAudience(token.audience, client.clientID, client.projectID); err != nil { return nil, err } - userInfo, err := s.userInfo(ctx, token.userID, token.scope, client.projectID, client.projectRoleAssertion, true) + userInfo, err := s.userInfo( + token.userID, + token.scope, + client.projectID, + client.projectRoleAssertion, + true, + true, + )(ctx, true, domain.TriggerTypePreUserinfoCreation) if err != nil { return nil, err } diff --git a/internal/api/oidc/oidc_integration_test.go b/internal/api/oidc/oidc_integration_test.go index c7a101afd3..6df8daa132 100644 --- a/internal/api/oidc/oidc_integration_test.go +++ b/internal/api/oidc/oidc_integration_test.go @@ -77,7 +77,7 @@ func Test_ZITADEL_API_missing_audience_scope(t *testing.T) { tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, false) - assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID) ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken)) @@ -142,7 +142,7 @@ func Test_ZITADEL_API_missing_mfa_2fa_setup(t *testing.T) { code := assertCodeResponse(t, linkResp.GetCallbackUrl()) tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) - assertIDTokenClaims(t, tokens.IDTokenClaims, userResp.GetUserId(), armPassword, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, userResp.GetUserId(), armPassword, startTime, changeTime, sessionID) ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken)) @@ -173,7 +173,7 @@ func Test_ZITADEL_API_missing_mfa_policy(t *testing.T) { code := assertCodeResponse(t, linkResp.GetCallbackUrl()) tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) - assertIDTokenClaims(t, tokens.IDTokenClaims, userID, armPassword, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, userID, armPassword, startTime, changeTime, sessionID) ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken)) @@ -227,7 +227,7 @@ func Test_ZITADEL_API_success(t *testing.T) { tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, false) - assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID) ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken)) @@ -261,7 +261,7 @@ func Test_ZITADEL_API_glob_redirects(t *testing.T) { tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, false) - assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID) ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken)) @@ -290,7 +290,7 @@ func Test_ZITADEL_API_inactive_access_token(t *testing.T) { tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, true) - assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID) // make sure token works ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken)) @@ -332,7 +332,7 @@ func Test_ZITADEL_API_terminated_session(t *testing.T) { tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, true) - assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID) // make sure token works ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken)) @@ -402,7 +402,7 @@ func Test_ZITADEL_API_terminated_session_user_disabled(t *testing.T) { tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, true) - assertIDTokenClaims(t, tokens.IDTokenClaims, disabledUser.GetUserId(), armPassword, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, disabledUser.GetUserId(), armPassword, startTime, changeTime, sessionID) // make sure token works ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken)) diff --git a/internal/api/oidc/token.go b/internal/api/oidc/token.go index c45eb98acb..be3a30ed73 100644 --- a/internal/api/oidc/token.go +++ b/internal/api/oidc/token.go @@ -29,8 +29,8 @@ In some cases step 1 till 3 are completely implemented in the command package, for example the v2 code exchange and refresh token. */ -func (s *Server) accessTokenResponseFromSession(ctx context.Context, client op.Client, session *command.OIDCSession, state, projectID string, projectRoleAssertion bool) (_ *oidc.AccessTokenResponse, err error) { - getUserInfo := s.getUserInfoOnce(session.UserID, projectID, projectRoleAssertion, session.Scope) +func (s *Server) accessTokenResponseFromSession(ctx context.Context, client op.Client, session *command.OIDCSession, state, projectID string, projectRoleAssertion, accessTokenRoleAssertion, idTokenRoleAssertion, userInfoAssertion bool) (_ *oidc.AccessTokenResponse, err error) { + getUserInfo := s.getUserInfo(session.UserID, projectID, projectRoleAssertion, userInfoAssertion, session.Scope) getSigner := s.getSignerOnce() resp := &oidc.AccessTokenResponse{ @@ -43,7 +43,7 @@ func (s *Server) accessTokenResponseFromSession(ctx context.Context, client op.C // If the session does not have a token ID, it is an implicit ID-Token only response. if session.TokenID != "" { if client.AccessTokenType() == op.AccessTokenTypeJWT { - resp.AccessToken, err = s.createJWT(ctx, client, session, getUserInfo, getSigner) + resp.AccessToken, err = s.createJWT(ctx, client, session, getUserInfo, accessTokenRoleAssertion, getSigner) } else { resp.AccessToken, err = op.CreateBearerToken(session.TokenID, session.UserID, s.opCrypto) } @@ -53,7 +53,7 @@ func (s *Server) accessTokenResponseFromSession(ctx context.Context, client op.C } if slices.Contains(session.Scope, oidc.ScopeOpenID) { - resp.IDToken, _, err = s.createIDToken(ctx, client, getUserInfo, getSigner, session.SessionID, resp.AccessToken, session.Audience, session.AuthMethods, session.AuthTime, session.Nonce, session.Actor) + resp.IDToken, _, err = s.createIDToken(ctx, client, getUserInfo, idTokenRoleAssertion, getSigner, session.SessionID, resp.AccessToken, session.Audience, session.AuthMethods, session.AuthTime, session.Nonce, session.Actor) } return resp, err } @@ -92,31 +92,22 @@ func (s *Server) getSignerOnce() signerFunc { } // userInfoFunc is a getter function that allows add-hoc retrieval of a user. -type userInfoFunc func(ctx context.Context) (*oidc.UserInfo, error) +type userInfoFunc func(ctx context.Context, roleAssertion bool, triggerType domain.TriggerType) (*oidc.UserInfo, error) -// getUserInfoOnce returns a function which retrieves userinfo from the database once. -// Repeated calls of the returned function return the same results. -func (s *Server) getUserInfoOnce(userID, projectID string, projectRoleAssertion bool, scope []string) userInfoFunc { - var ( - once sync.Once - userInfo *oidc.UserInfo - err error - ) - return func(ctx context.Context) (*oidc.UserInfo, error) { - once.Do(func() { - ctx, span := tracing.NewSpan(ctx) - defer func() { span.EndWithError(err) }() - userInfo, err = s.userInfo(ctx, userID, scope, projectID, projectRoleAssertion, false) - }) - return userInfo, err +// getUserInfo returns a function which retrieves userinfo from the database once. +// However, each time, role claims are asserted and also action flows will trigger. +func (s *Server) getUserInfo(userID, projectID string, projectRoleAssertion, userInfoAssertion bool, scope []string) userInfoFunc { + userInfo := s.userInfo(userID, scope, projectID, projectRoleAssertion, userInfoAssertion, false) + return func(ctx context.Context, roleAssertion bool, triggerType domain.TriggerType) (*oidc.UserInfo, error) { + return userInfo(ctx, roleAssertion, triggerType) } } -func (*Server) createIDToken(ctx context.Context, client op.Client, getUserInfo userInfoFunc, getSigningKey signerFunc, sessionID, accessToken string, audience []string, authMethods []domain.UserAuthMethodType, authTime time.Time, nonce string, actor *domain.TokenActor) (idToken string, exp uint64, err error) { +func (*Server) createIDToken(ctx context.Context, client op.Client, getUserInfo userInfoFunc, roleAssertion bool, getSigningKey signerFunc, sessionID, accessToken string, audience []string, authMethods []domain.UserAuthMethodType, authTime time.Time, nonce string, actor *domain.TokenActor) (idToken string, exp uint64, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - userInfo, err := getUserInfo(ctx) + userInfo, err := getUserInfo(ctx, roleAssertion, domain.TriggerTypePreUserinfoCreation) if err != nil { return "", 0, err } @@ -156,11 +147,11 @@ func timeToOIDCExpiresIn(exp time.Time) uint64 { return uint64(time.Until(exp) / time.Second) } -func (*Server) createJWT(ctx context.Context, client op.Client, session *command.OIDCSession, getUserInfo userInfoFunc, getSigner signerFunc) (_ string, err error) { +func (s *Server) createJWT(ctx context.Context, client op.Client, session *command.OIDCSession, getUserInfo userInfoFunc, assertRoles bool, getSigner signerFunc) (_ string, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - userInfo, err := getUserInfo(ctx) + userInfo, err := getUserInfo(ctx, assertRoles, domain.TriggerTypePreAccessTokenCreation) if err != nil { return "", err } diff --git a/internal/api/oidc/token_client_credentials.go b/internal/api/oidc/token_client_credentials.go index e0cb29770b..4b3bf20acd 100644 --- a/internal/api/oidc/token_client_credentials.go +++ b/internal/api/oidc/token_client_credentials.go @@ -47,5 +47,5 @@ func (s *Server) ClientCredentialsExchange(ctx context.Context, r *op.ClientRequ false, ) - return response(s.accessTokenResponseFromSession(ctx, client, session, "", "", false)) + return response(s.accessTokenResponseFromSession(ctx, client, session, "", "", false, true, false, false)) } diff --git a/internal/api/oidc/token_client_credentials_integration_test.go b/internal/api/oidc/token_client_credentials_integration_test.go index 3517438596..21a1c4de75 100644 --- a/internal/api/oidc/token_client_credentials_integration_test.go +++ b/internal/api/oidc/token_client_credentials_integration_test.go @@ -4,6 +4,7 @@ package oidc_test import ( "testing" + "time" "github.com/brianvoe/gofakeit/v6" "github.com/stretchr/testify/assert" @@ -18,10 +19,13 @@ import ( ) func TestServer_ClientCredentialsExchange(t *testing.T) { - userID, clientID, clientSecret, err := Tester.CreateOIDCCredentialsClient(CTX) + machine, name, clientID, clientSecret, err := Tester.CreateOIDCCredentialsClient(CTX) require.NoError(t, err) type claims struct { + name string + username string + updated time.Time resourceOwnerID any resourceOwnerName any resourceOwnerPrimaryDomain any @@ -78,6 +82,17 @@ func TestServer_ClientCredentialsExchange(t *testing.T) { clientSecret: clientSecret, scope: []string{oidc.ScopeOpenID}, }, + { + name: "openid, profile, email", + clientID: clientID, + clientSecret: clientSecret, + scope: []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail}, + wantClaims: claims{ + name: name, + username: name, + updated: machine.GetDetails().GetChangeDate().AsTime(), + }, + }, { name: "org id and domain scope", clientID: clientID, @@ -132,12 +147,20 @@ func TestServer_ClientCredentialsExchange(t *testing.T) { } require.NoError(t, err) require.NotNil(t, tokens) - userinfo, err := rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, oidc.BearerToken, userID, provider) + userinfo, err := rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, oidc.BearerToken, machine.GetUserId(), provider) require.NoError(t, err) assert.Equal(t, tt.wantClaims.resourceOwnerID, userinfo.Claims[oidc_api.ClaimResourceOwnerID]) assert.Equal(t, tt.wantClaims.resourceOwnerName, userinfo.Claims[oidc_api.ClaimResourceOwnerName]) assert.Equal(t, tt.wantClaims.resourceOwnerPrimaryDomain, userinfo.Claims[oidc_api.ClaimResourceOwnerPrimaryDomain]) assert.Equal(t, tt.wantClaims.orgDomain, userinfo.Claims[domain.OrgDomainPrimaryClaim]) + assert.Equal(t, tt.wantClaims.name, userinfo.Name) + assert.Equal(t, tt.wantClaims.username, userinfo.PreferredUsername) + assertOIDCTime(t, userinfo.UpdatedAt, tt.wantClaims.updated) + assert.Empty(t, userinfo.UserInfoProfile.FamilyName) + assert.Empty(t, userinfo.UserInfoProfile.GivenName) + assert.Empty(t, userinfo.UserInfoEmail) + assert.Empty(t, userinfo.UserInfoPhone) + assert.Empty(t, userinfo.Address) }) } } diff --git a/internal/api/oidc/token_code.go b/internal/api/oidc/token_code.go index 85aa847579..b7ccd1d22f 100644 --- a/internal/api/oidc/token_code.go +++ b/internal/api/oidc/token_code.go @@ -49,7 +49,7 @@ func (s *Server) CodeExchange(ctx context.Context, r *op.ClientRequest[oidc.Acce if err != nil { return nil, err } - return response(s.accessTokenResponseFromSession(ctx, client, session, state, client.client.ProjectID, client.client.ProjectRoleAssertion)) + return response(s.accessTokenResponseFromSession(ctx, client, session, state, client.client.ProjectID, client.client.ProjectRoleAssertion, client.client.AccessTokenRoleAssertion, client.client.IDTokenRoleAssertion, client.client.IDTokenUserinfoAssertion)) } // codeExchangeV1 creates a v2 token from a v1 auth request. diff --git a/internal/api/oidc/token_device.go b/internal/api/oidc/token_device.go index c70fa9c1e9..b574af1260 100644 --- a/internal/api/oidc/token_device.go +++ b/internal/api/oidc/token_device.go @@ -26,7 +26,7 @@ func (s *Server) DeviceToken(ctx context.Context, r *op.ClientRequest[oidc.Devic } session, err := s.command.CreateOIDCSessionFromDeviceAuth(ctx, r.Data.DeviceCode) if err == nil { - return response(s.accessTokenResponseFromSession(ctx, client, session, "", client.client.ProjectID, client.client.ProjectRoleAssertion)) + return response(s.accessTokenResponseFromSession(ctx, client, session, "", client.client.ProjectID, client.client.ProjectRoleAssertion, client.client.AccessTokenRoleAssertion, client.client.IDTokenRoleAssertion, client.client.IDTokenUserinfoAssertion)) } if errors.Is(err, context.DeadlineExceeded) { return nil, oidc.ErrSlowDown().WithParent(err) diff --git a/internal/api/oidc/token_exchange.go b/internal/api/oidc/token_exchange.go index 31b1a37db3..bd19e565cc 100644 --- a/internal/api/oidc/token_exchange.go +++ b/internal/api/oidc/token_exchange.go @@ -218,7 +218,7 @@ func validateTokenExchangeAudience(requestedAudience, subjectAudience, actorAudi // Both tokens may point to the same object (subjectToken) in case of a regular Token Exchange. // When the subject and actor Tokens point to different objects, the new tokens will be for impersonation / delegation. func (s *Server) createExchangeTokens(ctx context.Context, tokenType oidc.TokenType, client *Client, subjectToken, actorToken *exchangeToken, audience, scopes []string) (_ *oidc.TokenExchangeResponse, err error) { - getUserInfo := s.getUserInfoOnce(subjectToken.userID, client.client.ProjectID, client.client.ProjectRoleAssertion, scopes) + getUserInfo := s.getUserInfo(subjectToken.userID, client.client.ProjectID, client.client.ProjectRoleAssertion, client.IDTokenUserinfoClaimsAssertion(), scopes) getSigner := s.getSignerOnce() resp := &oidc.TokenExchangeResponse{ @@ -240,12 +240,12 @@ func (s *Server) createExchangeTokens(ctx context.Context, tokenType oidc.TokenT resp.IssuedTokenType = oidc.AccessTokenType case oidc.JWTTokenType: - resp.AccessToken, resp.RefreshToken, resp.ExpiresIn, err = s.createExchangeJWT(ctx, client, getUserInfo, getSigner, subjectToken.userID, subjectToken.resourceOwner, audience, scopes, actorToken.authMethods, actorToken.authTime, subjectToken.preferredLanguage, reason, actor) + resp.AccessToken, resp.RefreshToken, resp.ExpiresIn, err = s.createExchangeJWT(ctx, client, getUserInfo, client.client.AccessTokenRoleAssertion, getSigner, subjectToken.userID, subjectToken.resourceOwner, audience, scopes, actorToken.authMethods, actorToken.authTime, subjectToken.preferredLanguage, reason, actor) resp.TokenType = oidc.BearerToken resp.IssuedTokenType = oidc.JWTTokenType case oidc.IDTokenType: - resp.AccessToken, resp.ExpiresIn, err = s.createIDToken(ctx, client, getUserInfo, getSigner, "", resp.AccessToken, audience, actorToken.authMethods, actorToken.authTime, "", actor) + resp.AccessToken, resp.ExpiresIn, err = s.createIDToken(ctx, client, getUserInfo, client.client.IDTokenRoleAssertion, getSigner, "", resp.AccessToken, audience, actorToken.authMethods, actorToken.authTime, "", actor) resp.TokenType = TokenTypeNA resp.IssuedTokenType = oidc.IDTokenType @@ -259,7 +259,7 @@ func (s *Server) createExchangeTokens(ctx context.Context, tokenType oidc.TokenT } if slices.Contains(scopes, oidc.ScopeOpenID) && tokenType != oidc.IDTokenType { - resp.IDToken, _, err = s.createIDToken(ctx, client, getUserInfo, getSigner, sessionID, resp.AccessToken, audience, actorToken.authMethods, actorToken.authTime, "", actor) + resp.IDToken, _, err = s.createIDToken(ctx, client, getUserInfo, client.client.IDTokenRoleAssertion, getSigner, sessionID, resp.AccessToken, audience, actorToken.authMethods, actorToken.authTime, "", actor) if err != nil { return nil, err } @@ -313,6 +313,7 @@ func (s *Server) createExchangeJWT( ctx context.Context, client *Client, getUserInfo userInfoFunc, + roleAssertion bool, getSigner signerFunc, userID, resourceOwner string, @@ -342,7 +343,7 @@ func (s *Server) createExchangeJWT( actor, slices.Contains(scope, oidc.ScopeOfflineAccess), ) - accessToken, err = s.createJWT(ctx, client, session, getUserInfo, getSigner) + accessToken, err = s.createJWT(ctx, client, session, getUserInfo, roleAssertion, getSigner) if err != nil { return "", "", 0, err } diff --git a/internal/api/oidc/token_jwt_profile.go b/internal/api/oidc/token_jwt_profile.go index b23a24f77f..fc0c31a6eb 100644 --- a/internal/api/oidc/token_jwt_profile.go +++ b/internal/api/oidc/token_jwt_profile.go @@ -54,7 +54,7 @@ func (s *Server) JWTProfile(ctx context.Context, r *op.Request[oidc.JWTProfileGr nil, false, ) - return response(s.accessTokenResponseFromSession(ctx, client, session, "", "", false)) + return response(s.accessTokenResponseFromSession(ctx, client, session, "", "", false, true, false, false)) } func (s *Server) verifyJWTProfile(ctx context.Context, req *oidc.JWTProfileGrantRequest) (user *query.User, tokenRequest *oidc.JWTTokenRequest, err error) { diff --git a/internal/api/oidc/token_jwt_profile_integration_test.go b/internal/api/oidc/token_jwt_profile_integration_test.go index b80ea09bcd..0ad8d76da2 100644 --- a/internal/api/oidc/token_jwt_profile_integration_test.go +++ b/internal/api/oidc/token_jwt_profile_integration_test.go @@ -4,6 +4,7 @@ package oidc_test import ( "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -16,10 +17,13 @@ import ( ) func TestServer_JWTProfile(t *testing.T) { - userID, keyData, err := Tester.CreateOIDCJWTProfileClient(CTX) + user, name, keyData, err := Tester.CreateOIDCJWTProfileClient(CTX) require.NoError(t, err) type claims struct { + name string + username string + updated time.Time resourceOwnerID any resourceOwnerName any resourceOwnerPrimaryDomain any @@ -37,6 +41,16 @@ func TestServer_JWTProfile(t *testing.T) { keyData: keyData, scope: []string{oidc.ScopeOpenID}, }, + { + name: "openid, profile, email", + keyData: keyData, + scope: []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail}, + wantClaims: claims{ + name: name, + username: name, + updated: user.GetDetails().GetChangeDate().AsTime(), + }, + }, { name: "org id and domain scope", keyData: keyData, @@ -92,12 +106,20 @@ func TestServer_JWTProfile(t *testing.T) { provider, err := rp.NewRelyingPartyOIDC(CTX, Tester.OIDCIssuer(), "", "", redirectURI, tt.scope) require.NoError(t, err) - userinfo, err := rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, oidc.BearerToken, userID, provider) + userinfo, err := rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, oidc.BearerToken, user.GetUserId(), provider) require.NoError(t, err) assert.Equal(t, tt.wantClaims.resourceOwnerID, userinfo.Claims[oidc_api.ClaimResourceOwnerID]) assert.Equal(t, tt.wantClaims.resourceOwnerName, userinfo.Claims[oidc_api.ClaimResourceOwnerName]) assert.Equal(t, tt.wantClaims.resourceOwnerPrimaryDomain, userinfo.Claims[oidc_api.ClaimResourceOwnerPrimaryDomain]) assert.Equal(t, tt.wantClaims.orgDomain, userinfo.Claims[domain.OrgDomainPrimaryClaim]) + assert.Equal(t, tt.wantClaims.name, userinfo.Name) + assert.Equal(t, tt.wantClaims.username, userinfo.PreferredUsername) + assertOIDCTime(t, userinfo.UpdatedAt, tt.wantClaims.updated) + assert.Empty(t, userinfo.UserInfoProfile.FamilyName) + assert.Empty(t, userinfo.UserInfoProfile.GivenName) + assert.Empty(t, userinfo.UserInfoEmail) + assert.Empty(t, userinfo.UserInfoPhone) + assert.Empty(t, userinfo.Address) }) } } diff --git a/internal/api/oidc/token_refresh.go b/internal/api/oidc/token_refresh.go index 66a8b8a263..1dcce2879a 100644 --- a/internal/api/oidc/token_refresh.go +++ b/internal/api/oidc/token_refresh.go @@ -28,7 +28,7 @@ func (s *Server) RefreshToken(ctx context.Context, r *op.ClientRequest[oidc.Refr session, err := s.command.ExchangeOIDCSessionRefreshAndAccessToken(ctx, r.Data.RefreshToken, r.Data.Scopes, refreshTokenComplianceChecker()) if err == nil { - return response(s.accessTokenResponseFromSession(ctx, client, session, "", client.client.ProjectID, client.client.ProjectRoleAssertion)) + return response(s.accessTokenResponseFromSession(ctx, client, session, "", client.client.ProjectID, client.client.ProjectRoleAssertion, client.client.AccessTokenRoleAssertion, client.client.IDTokenRoleAssertion, client.client.IDTokenUserinfoAssertion)) } else if errors.Is(err, zerrors.ThrowPreconditionFailed(nil, "OIDCS-JOI23", "Errors.OIDCSession.RefreshTokenInvalid")) { // We try again for v1 tokens when we encountered specific parsing error return s.refreshTokenV1(ctx, client, r) @@ -78,7 +78,7 @@ func (s *Server) refreshTokenV1(ctx context.Context, client *Client, r *op.Clien return nil, err } - return response(s.accessTokenResponseFromSession(ctx, client, session, "", client.client.ProjectID, client.client.ProjectRoleAssertion)) + return response(s.accessTokenResponseFromSession(ctx, client, session, "", client.client.ProjectID, client.client.ProjectRoleAssertion, client.client.AccessTokenRoleAssertion, client.client.IDTokenRoleAssertion, client.client.IDTokenUserinfoAssertion)) } // refreshTokenComplianceChecker validates that the requested scope is a subset of the original auth request scope. diff --git a/internal/api/oidc/userinfo.go b/internal/api/oidc/userinfo.go index c21b746c49..7d352457bc 100644 --- a/internal/api/oidc/userinfo.go +++ b/internal/api/oidc/userinfo.go @@ -8,6 +8,7 @@ import ( "net/http" "slices" "strings" + "sync" "github.com/dop251/goja" "github.com/zitadel/logging" @@ -55,7 +56,14 @@ func (s *Server) UserInfo(ctx context.Context, r *op.Request[oidc.UserInfoReques } } - userInfo, err := s.userInfo(ctx, token.userID, token.scope, projectID, assertion, false) + userInfo, err := s.userInfo( + token.userID, + token.scope, + projectID, + assertion, + true, + false, + )(ctx, true, domain.TriggerTypePreUserinfoCreation) if err != nil { return nil, err } @@ -66,24 +74,44 @@ func (s *Server) UserInfo(ctx context.Context, r *op.Request[oidc.UserInfoReques // The returned UserInfo contains standard and reserved claims, documented // here: https://zitadel.com/docs/apis/openidoauth/claims. // +// User information is only retrieved once from the database. +// However, each time, role claims are asserted and also action flows will trigger. +// // projectID is an optional parameter which defines the default audience when there are any (or all) role claims requested. // projectRoleAssertion sets the default of returning all project roles, only if no specific roles were requested in the scope. +// roleAssertion decides whether the roles will be returned (in the token or response) +// userInfoAssertion decides whether the user information (profile data like name, email, ...) are returned // // currentProjectOnly can be set to use the current project ID only and ignore the audience from the scope. // It should be set in cases where the client doesn't need to know roles outside its own project, // for example an introspection client. -func (s *Server) userInfo(ctx context.Context, userID string, scope []string, projectID string, projectRoleAssertion, currentProjectOnly bool) (_ *oidc.UserInfo, err error) { - ctx, span := tracing.NewSpan(ctx) - defer func() { span.EndWithError(err) }() +func (s *Server) userInfo( + userID string, + scope []string, + projectID string, + projectRoleAssertion, userInfoAssertion, currentProjectOnly bool, +) func(ctx context.Context, roleAssertion bool, triggerType domain.TriggerType) (_ *oidc.UserInfo, err error) { + var ( + once sync.Once + userInfo *oidc.UserInfo + qu *query.OIDCUserInfo + roleAudience, requestedRoles []string + ) + return func(ctx context.Context, roleAssertion bool, triggerType domain.TriggerType) (_ *oidc.UserInfo, err error) { + once.Do(func() { + ctx, span := tracing.NewSpan(ctx) + defer func() { span.EndWithError(err) }() - roleAudience, requestedRoles := prepareRoles(ctx, scope, projectID, projectRoleAssertion, currentProjectOnly) - qu, err := s.query.GetOIDCUserInfo(ctx, userID, roleAudience) - if err != nil { - return nil, err + roleAudience, requestedRoles = prepareRoles(ctx, scope, projectID, projectRoleAssertion, currentProjectOnly) + qu, err = s.query.GetOIDCUserInfo(ctx, userID, roleAudience) + if err != nil { + return + } + userInfo = userInfoToOIDC(qu, userInfoAssertion, scope, s.assetAPIPrefix(ctx)) + }) + userInfoWithRoles := assertRoles(projectID, qu, roleAudience, requestedRoles, roleAssertion, userInfo) + return userInfoWithRoles, s.userinfoFlows(ctx, qu, userInfoWithRoles, triggerType) } - - userInfo := userInfoToOIDC(projectID, qu, scope, roleAudience, requestedRoles, s.assetAPIPrefix(ctx)) - return userInfo, s.userinfoFlows(ctx, qu, userInfo) } // prepareRoles scans the requested scopes and builds the requested roles @@ -120,20 +148,32 @@ func prepareRoles(ctx context.Context, scope []string, projectID string, project return roleAudience, requestedRoles } -func userInfoToOIDC(projectID string, user *query.OIDCUserInfo, scope, roleAudience, requestedRoles []string, assetPrefix string) *oidc.UserInfo { +func userInfoToOIDC(user *query.OIDCUserInfo, userInfoAssertion bool, scope []string, assetPrefix string) *oidc.UserInfo { out := new(oidc.UserInfo) for _, s := range scope { switch s { case oidc.ScopeOpenID: out.Subject = user.User.ID case oidc.ScopeEmail: + if !userInfoAssertion { + continue + } out.UserInfoEmail = userInfoEmailToOIDC(user.User) case oidc.ScopeProfile: + if !userInfoAssertion { + continue + } out.UserInfoProfile = userInfoProfileToOidc(user.User, assetPrefix) case oidc.ScopePhone: + if !userInfoAssertion { + continue + } out.UserInfoPhone = userInfoPhoneToOIDC(user.User) case oidc.ScopeAddress: - //TODO: handle address for human users as soon as implemented + if !userInfoAssertion { + continue + } + // TODO: handle address for human users as soon as implemented case ScopeUserMetaData: setUserInfoMetadata(user.Metadata, out) case ScopeResourceOwner: @@ -148,12 +188,19 @@ func userInfoToOIDC(projectID string, user *query.OIDCUserInfo, scope, roleAudie } } } + return out +} +func assertRoles(projectID string, user *query.OIDCUserInfo, roleAudience, requestedRoles []string, assertion bool, info *oidc.UserInfo) *oidc.UserInfo { + if !assertion { + return info + } + userInfo := *info // prevent returning obtained grants if none where requested if (projectID != "" && len(requestedRoles) > 0) || len(roleAudience) > 0 { - setUserInfoRoleClaims(out, newProjectRoles(projectID, user.UserGrants, requestedRoles)) + setUserInfoRoleClaims(&userInfo, newProjectRoles(projectID, user.UserGrants, requestedRoles)) } - return out + return &userInfo } func userInfoEmailToOIDC(user *query.User) oidc.UserInfoEmail { @@ -230,11 +277,11 @@ func setUserInfoRoleClaims(userInfo *oidc.UserInfo, roles *projectsRoles) { } } -func (s *Server) userinfoFlows(ctx context.Context, qu *query.OIDCUserInfo, userInfo *oidc.UserInfo) (err error) { +func (s *Server) userinfoFlows(ctx context.Context, qu *query.OIDCUserInfo, userInfo *oidc.UserInfo, triggerType domain.TriggerType) (err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - queriedActions, err := s.query.GetActiveActionsByFlowAndTriggerType(ctx, domain.FlowTypeCustomiseToken, domain.TriggerTypePreUserinfoCreation, qu.User.ResourceOwner) + queriedActions, err := s.query.GetActiveActionsByFlowAndTriggerType(ctx, domain.FlowTypeCustomiseToken, triggerType, qu.User.ResourceOwner) if err != nil { return err } diff --git a/internal/api/oidc/userinfo_integration_test.go b/internal/api/oidc/userinfo_integration_test.go index 22e688ff4b..7f39ed38ba 100644 --- a/internal/api/oidc/userinfo_integration_test.go +++ b/internal/api/oidc/userinfo_integration_test.go @@ -231,9 +231,9 @@ func TestServer_UserInfo_Issue6662(t *testing.T) { project, err := Tester.CreateProject(CTX) projectID := project.GetId() require.NoError(t, err) - userID, clientID, clientSecret, err := Tester.CreateOIDCCredentialsClient(CTX) + user, _, clientID, clientSecret, err := Tester.CreateOIDCCredentialsClient(CTX) require.NoError(t, err) - addProjectRolesGrants(t, userID, projectID, roleFoo, roleBar) + addProjectRolesGrants(t, user.GetUserId(), projectID, roleFoo, roleBar) scope := []string{oidc.ScopeProfile, oidc.ScopeOpenID, oidc.ScopeEmail, oidc.ScopeOfflineAccess, oidc_api.ScopeProjectRolePrefix + roleFoo, @@ -245,7 +245,7 @@ func TestServer_UserInfo_Issue6662(t *testing.T) { tokens, err := rp.ClientCredentials(CTX, provider, nil) require.NoError(t, err) - userinfo, err := rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, tokens.TokenType, userID, provider) + userinfo, err := rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, tokens.TokenType, user.GetUserId(), provider) require.NoError(t, err) assertProjectRoleClaims(t, projectID, userinfo.Claims, false, roleFoo) } @@ -291,7 +291,7 @@ func getTokens(t *testing.T, clientID string, scope []string) *oidc.Tokens[*oidc tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, true) - assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID) return tokens } diff --git a/internal/api/oidc/userinfo_test.go b/internal/api/oidc/userinfo_test.go index 65354a4040..741f7eed36 100644 --- a/internal/api/oidc/userinfo_test.go +++ b/internal/api/oidc/userinfo_test.go @@ -3,7 +3,6 @@ package oidc import ( "context" "encoding/base64" - "fmt" "testing" "time" @@ -267,11 +266,9 @@ func Test_userInfoToOIDC(t *testing.T) { } type args struct { - projectID string - user *query.OIDCUserInfo - scope []string - roleAudience []string - requestedRoles []string + user *query.OIDCUserInfo + userInfoAssertion bool + scope []string } tests := []struct { name string @@ -281,25 +278,22 @@ func Test_userInfoToOIDC(t *testing.T) { { name: "human, empty", args: args{ - projectID: "project1", - user: humanUserInfo, + user: humanUserInfo, }, want: &oidc.UserInfo{}, }, { name: "machine, empty", args: args{ - projectID: "project1", - user: machineUserInfo, + user: machineUserInfo, }, want: &oidc.UserInfo{}, }, { name: "human, scope openid", args: args{ - projectID: "project1", - user: humanUserInfo, - scope: []string{oidc.ScopeOpenID}, + user: humanUserInfo, + scope: []string{oidc.ScopeOpenID}, }, want: &oidc.UserInfo{ Subject: "human1", @@ -308,20 +302,19 @@ func Test_userInfoToOIDC(t *testing.T) { { name: "machine, scope openid", args: args{ - projectID: "project1", - user: machineUserInfo, - scope: []string{oidc.ScopeOpenID}, + user: machineUserInfo, + scope: []string{oidc.ScopeOpenID}, }, want: &oidc.UserInfo{ Subject: "machine1", }, }, { - name: "human, scope email", + name: "human, scope email, profileInfoAssertion", args: args{ - projectID: "project1", - user: humanUserInfo, - scope: []string{oidc.ScopeEmail}, + user: humanUserInfo, + userInfoAssertion: true, + scope: []string{oidc.ScopeEmail}, }, want: &oidc.UserInfo{ UserInfoEmail: oidc.UserInfoEmail{ @@ -331,22 +324,29 @@ func Test_userInfoToOIDC(t *testing.T) { }, }, { - name: "machine, scope email", + name: "human, scope email", args: args{ - projectID: "project1", - user: machineUserInfo, - scope: []string{oidc.ScopeEmail}, + user: humanUserInfo, + scope: []string{oidc.ScopeEmail}, + }, + want: &oidc.UserInfo{}, + }, + { + name: "machine, scope email, profileInfoAssertion", + args: args{ + user: machineUserInfo, + scope: []string{oidc.ScopeEmail}, }, want: &oidc.UserInfo{ UserInfoEmail: oidc.UserInfoEmail{}, }, }, { - name: "human, scope profile", + name: "human, scope profile, profileInfoAssertion", args: args{ - projectID: "project1", - user: humanUserInfo, - scope: []string{oidc.ScopeProfile}, + user: humanUserInfo, + userInfoAssertion: true, + scope: []string{oidc.ScopeProfile}, }, want: &oidc.UserInfo{ UserInfoProfile: oidc.UserInfoProfile{ @@ -363,11 +363,11 @@ func Test_userInfoToOIDC(t *testing.T) { }, }, { - name: "machine, scope profile", + name: "machine, scope profile, profileInfoAssertion", args: args{ - projectID: "project1", - user: machineUserInfo, - scope: []string{oidc.ScopeProfile}, + user: machineUserInfo, + userInfoAssertion: true, + scope: []string{oidc.ScopeProfile}, }, want: &oidc.UserInfo{ UserInfoProfile: oidc.UserInfoProfile{ @@ -378,11 +378,19 @@ func Test_userInfoToOIDC(t *testing.T) { }, }, { - name: "human, scope phone", + name: "machine, scope profile", args: args{ - projectID: "project1", - user: humanUserInfo, - scope: []string{oidc.ScopePhone}, + user: machineUserInfo, + scope: []string{oidc.ScopeProfile}, + }, + want: &oidc.UserInfo{}, + }, + { + name: "human, scope phone, profileInfoAssertion", + args: args{ + user: humanUserInfo, + userInfoAssertion: true, + scope: []string{oidc.ScopePhone}, }, want: &oidc.UserInfo{ UserInfoPhone: oidc.UserInfoPhone{ @@ -391,12 +399,19 @@ func Test_userInfoToOIDC(t *testing.T) { }, }, }, + { + name: "human, scope phone", + args: args{ + user: humanUserInfo, + scope: []string{oidc.ScopePhone}, + }, + want: &oidc.UserInfo{}, + }, { name: "machine, scope phone", args: args{ - projectID: "project1", - user: machineUserInfo, - scope: []string{oidc.ScopePhone}, + user: machineUserInfo, + scope: []string{oidc.ScopePhone}, }, want: &oidc.UserInfo{ UserInfoPhone: oidc.UserInfoPhone{}, @@ -405,9 +420,8 @@ func Test_userInfoToOIDC(t *testing.T) { { name: "human, scope metadata", args: args{ - projectID: "project1", - user: humanUserInfo, - scope: []string{ScopeUserMetaData}, + user: humanUserInfo, + scope: []string{ScopeUserMetaData}, }, want: &oidc.UserInfo{ Claims: map[string]any{ @@ -421,18 +435,16 @@ func Test_userInfoToOIDC(t *testing.T) { { name: "machine, scope metadata, none found", args: args{ - projectID: "project1", - user: machineUserInfo, - scope: []string{ScopeUserMetaData}, + user: machineUserInfo, + scope: []string{ScopeUserMetaData}, }, want: &oidc.UserInfo{}, }, { name: "machine, scope resource owner", args: args{ - projectID: "project1", - user: machineUserInfo, - scope: []string{ScopeResourceOwner}, + user: machineUserInfo, + scope: []string{ScopeResourceOwner}, }, want: &oidc.UserInfo{ Claims: map[string]any{ @@ -445,9 +457,8 @@ func Test_userInfoToOIDC(t *testing.T) { { name: "human, scope org primary domain prefix", args: args{ - projectID: "project1", - user: humanUserInfo, - scope: []string{domain.OrgDomainPrimaryScope + "foo.com"}, + user: humanUserInfo, + scope: []string{domain.OrgDomainPrimaryScope + "foo.com"}, }, want: &oidc.UserInfo{ Claims: map[string]any{ @@ -458,9 +469,8 @@ func Test_userInfoToOIDC(t *testing.T) { { name: "machine, scope org id", args: args{ - projectID: "project1", - user: machineUserInfo, - scope: []string{domain.OrgIDScope + "orgID"}, + user: machineUserInfo, + scope: []string{domain.OrgIDScope + "orgID"}, }, want: &oidc.UserInfo{ Claims: map[string]any{ @@ -471,50 +481,11 @@ func Test_userInfoToOIDC(t *testing.T) { }, }, }, - { - name: "human, roleAudience", - args: args{ - projectID: "project1", - user: humanUserInfo, - roleAudience: []string{"project1"}, - }, - want: &oidc.UserInfo{ - Claims: map[string]any{ - ClaimProjectRoles: projectRoles{ - "role1": {"orgID": "orgDomain"}, - "role2": {"orgID": "orgDomain"}, - }, - fmt.Sprintf(ClaimProjectRolesFormat, "project1"): projectRoles{ - "role1": {"orgID": "orgDomain"}, - "role2": {"orgID": "orgDomain"}, - }, - }, - }, - }, - { - name: "human, requested roles", - args: args{ - projectID: "project1", - user: humanUserInfo, - roleAudience: []string{"project1"}, - requestedRoles: []string{"role2"}, - }, - want: &oidc.UserInfo{ - Claims: map[string]any{ - ClaimProjectRoles: projectRoles{ - "role2": {"orgID": "orgDomain"}, - }, - fmt.Sprintf(ClaimProjectRolesFormat, "project1"): projectRoles{ - "role2": {"orgID": "orgDomain"}, - }, - }, - }, - }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { assetPrefix := "https://foo.com/assets" - got := userInfoToOIDC(tt.args.projectID, tt.args.user, tt.args.scope, tt.args.roleAudience, tt.args.requestedRoles, assetPrefix) + got := userInfoToOIDC(tt.args.user, tt.args.userInfoAssertion, tt.args.scope, assetPrefix) assert.Equal(t, tt.want, got) }) } diff --git a/internal/integration/oidc.go b/internal/integration/oidc.go index 3ba655c65b..2d7d6a105e 100644 --- a/internal/integration/oidc.go +++ b/internal/integration/oidc.go @@ -281,42 +281,42 @@ func CheckRedirect(req *http.Request) (*url.URL, error) { return resp.Location() } -func (s *Tester) CreateOIDCCredentialsClient(ctx context.Context) (userID, clientID, clientSecret string, err error) { - name := gofakeit.Username() - user, err := s.Client.Mgmt.AddMachineUser(ctx, &management.AddMachineUserRequest{ +func (s *Tester) CreateOIDCCredentialsClient(ctx context.Context) (machine *management.AddMachineUserResponse, name, clientID, clientSecret string, err error) { + name = gofakeit.Username() + machine, err = s.Client.Mgmt.AddMachineUser(ctx, &management.AddMachineUserRequest{ Name: name, UserName: name, AccessTokenType: user.AccessTokenType_ACCESS_TOKEN_TYPE_JWT, }) if err != nil { - return "", "", "", err + return nil, "", "", "", err } secret, err := s.Client.Mgmt.GenerateMachineSecret(ctx, &management.GenerateMachineSecretRequest{ - UserId: user.GetUserId(), + UserId: machine.GetUserId(), }) if err != nil { - return "", "", "", err + return nil, "", "", "", err } - return user.GetUserId(), secret.GetClientId(), secret.GetClientSecret(), nil + return machine, name, secret.GetClientId(), secret.GetClientSecret(), nil } -func (s *Tester) CreateOIDCJWTProfileClient(ctx context.Context) (userID string, keyData []byte, err error) { - name := gofakeit.Username() - user, err := s.Client.Mgmt.AddMachineUser(ctx, &management.AddMachineUserRequest{ +func (s *Tester) CreateOIDCJWTProfileClient(ctx context.Context) (machine *management.AddMachineUserResponse, name string, keyData []byte, err error) { + name = gofakeit.Username() + machine, err = s.Client.Mgmt.AddMachineUser(ctx, &management.AddMachineUserRequest{ Name: name, UserName: name, AccessTokenType: user.AccessTokenType_ACCESS_TOKEN_TYPE_JWT, }) if err != nil { - return "", nil, err + return nil, "", nil, err } keyResp, err := s.Client.Mgmt.AddMachineKey(ctx, &management.AddMachineKeyRequest{ - UserId: user.GetUserId(), + UserId: machine.GetUserId(), Type: authn.KeyType_KEY_TYPE_JSON, ExpirationDate: timestamppb.New(time.Now().Add(time.Hour)), }) if err != nil { - return "", nil, err + return nil, "", nil, err } - return user.GetUserId(), keyResp.GetKeyDetails(), nil + return machine, name, keyResp.GetKeyDetails(), nil }