feat(oidc): optimize the userinfo endpoint (#7706)

* feat(oidc): optimize the userinfo endpoint

* store project ID in the access token

* query for projectID if not in token

* add scope based tests

* Revert "store project ID in the access token"

This reverts commit 5f0262f239.

* query project role assertion

* use project role assertion setting to return roles

* workaround eventual consistency and handle PAT

* do not append empty project id
This commit is contained in:
Tim Möhlmann
2024-04-09 16:15:35 +03:00
committed by GitHub
parent c8e0b30e17
commit 6a51c4b0f5
25 changed files with 528 additions and 159 deletions

View File

@@ -28,14 +28,14 @@ var (
)
func TestOPStorage_CreateAuthRequest(t *testing.T) {
clientID := createClient(t)
clientID, _ := createClient(t)
id := createAuthRequest(t, clientID, redirectURI)
require.Contains(t, id, command.IDPrefixV2)
}
func TestOPStorage_CreateAccessToken_code(t *testing.T) {
clientID := createClient(t)
clientID, _ := createClient(t)
authRequestID := createAuthRequest(t, clientID, redirectURI)
sessionID, sessionToken, startTime, changeTime := Tester.CreateVerifiedWebAuthNSession(t, CTXLOGIN, User.GetUserId())
linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
@@ -124,7 +124,7 @@ func TestOPStorage_CreateAccessToken_implicit(t *testing.T) {
}
func TestOPStorage_CreateAccessAndRefreshTokens_code(t *testing.T) {
clientID := createClient(t)
clientID, _ := createClient(t)
authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, oidc.ScopeOfflineAccess)
sessionID, sessionToken, startTime, changeTime := Tester.CreateVerifiedWebAuthNSession(t, CTXLOGIN, User.GetUserId())
linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
@@ -147,7 +147,7 @@ func TestOPStorage_CreateAccessAndRefreshTokens_code(t *testing.T) {
}
func TestOPStorage_CreateAccessAndRefreshTokens_refresh(t *testing.T) {
clientID := createClient(t)
clientID, _ := createClient(t)
provider, err := Tester.CreateRelyingParty(CTX, clientID, redirectURI)
require.NoError(t, err)
authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, oidc.ScopeOfflineAccess)
@@ -183,7 +183,7 @@ func TestOPStorage_CreateAccessAndRefreshTokens_refresh(t *testing.T) {
}
func TestOPStorage_RevokeToken_access_token(t *testing.T) {
clientID := createClient(t)
clientID, _ := createClient(t)
provider, err := Tester.CreateRelyingParty(CTX, clientID, redirectURI)
require.NoError(t, err)
authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, oidc.ScopeOfflineAccess)
@@ -226,7 +226,7 @@ func TestOPStorage_RevokeToken_access_token(t *testing.T) {
}
func TestOPStorage_RevokeToken_access_token_invalid_token_hint_type(t *testing.T) {
clientID := createClient(t)
clientID, _ := createClient(t)
provider, err := Tester.CreateRelyingParty(CTX, clientID, redirectURI)
require.NoError(t, err)
authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, oidc.ScopeOfflineAccess)
@@ -263,7 +263,7 @@ func TestOPStorage_RevokeToken_access_token_invalid_token_hint_type(t *testing.T
}
func TestOPStorage_RevokeToken_refresh_token(t *testing.T) {
clientID := createClient(t)
clientID, _ := createClient(t)
provider, err := Tester.CreateRelyingParty(CTX, clientID, redirectURI)
require.NoError(t, err)
authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, oidc.ScopeOfflineAccess)
@@ -306,7 +306,7 @@ func TestOPStorage_RevokeToken_refresh_token(t *testing.T) {
}
func TestOPStorage_RevokeToken_refresh_token_invalid_token_type_hint(t *testing.T) {
clientID := createClient(t)
clientID, _ := createClient(t)
provider, err := Tester.CreateRelyingParty(CTX, clientID, redirectURI)
require.NoError(t, err)
authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, oidc.ScopeOfflineAccess)
@@ -343,7 +343,7 @@ func TestOPStorage_RevokeToken_refresh_token_invalid_token_type_hint(t *testing.
}
func TestOPStorage_RevokeToken_invalid_client(t *testing.T) {
clientID := createClient(t)
clientID, _ := createClient(t)
authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, oidc.ScopeOfflineAccess)
sessionID, sessionToken, startTime, changeTime := Tester.CreateVerifiedWebAuthNSession(t, CTXLOGIN, User.GetUserId())
linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
@@ -365,7 +365,7 @@ func TestOPStorage_RevokeToken_invalid_client(t *testing.T) {
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
// simulate second client (not part of the audience) trying to revoke the token
otherClientID := createClient(t)
otherClientID, _ := createClient(t)
provider, err := Tester.CreateRelyingParty(CTX, otherClientID, redirectURI)
require.NoError(t, err)
err = rp.RevokeToken(CTX, provider, tokens.AccessToken, "")
@@ -373,7 +373,7 @@ func TestOPStorage_RevokeToken_invalid_client(t *testing.T) {
}
func TestOPStorage_TerminateSession(t *testing.T) {
clientID := createClient(t)
clientID, _ := createClient(t)
provider, err := Tester.CreateRelyingParty(CTX, clientID, redirectURI)
require.NoError(t, err)
authRequestID := createAuthRequest(t, clientID, redirectURI)
@@ -410,7 +410,7 @@ func TestOPStorage_TerminateSession(t *testing.T) {
}
func TestOPStorage_TerminateSession_refresh_grant(t *testing.T) {
clientID := createClient(t)
clientID, _ := createClient(t)
provider, err := Tester.CreateRelyingParty(CTX, clientID, redirectURI)
require.NoError(t, err)
authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, oidc.ScopeOfflineAccess)
@@ -454,7 +454,7 @@ func TestOPStorage_TerminateSession_refresh_grant(t *testing.T) {
}
func TestOPStorage_TerminateSession_empty_id_token_hint(t *testing.T) {
clientID := createClient(t)
clientID, _ := createClient(t)
provider, err := Tester.CreateRelyingParty(CTX, clientID, redirectURI)
require.NoError(t, err)
authRequestID := createAuthRequest(t, clientID, redirectURI)

View File

@@ -25,36 +25,6 @@ import (
"github.com/zitadel/zitadel/pkg/grpc/user"
)
func TestOPStorage_SetUserinfoFromToken(t *testing.T) {
clientID := createClient(t)
authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeOfflineAccess)
sessionID, sessionToken, startTime, changeTime := Tester.CreateVerifiedWebAuthNSession(t, CTXLOGIN, User.GetUserId())
linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
AuthRequestId: authRequestID,
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
Session: &oidc_pb.Session{
SessionId: sessionID,
SessionToken: sessionToken,
},
},
})
require.NoError(t, err)
// code exchange
code := assertCodeResponse(t, linkResp.GetCallbackUrl())
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err)
assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
// test actual userinfo
provider, err := Tester.CreateRelyingParty(CTX, clientID, redirectURI)
require.NoError(t, err)
userinfo, err := rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.Subject, provider)
require.NoError(t, err)
assertUserinfo(t, userinfo)
}
func TestServer_Introspect(t *testing.T) {
project, err := Tester.CreateProject(CTX)
require.NoError(t, err)
@@ -172,17 +142,6 @@ func TestServer_Introspect(t *testing.T) {
}
}
func assertUserinfo(t *testing.T, userinfo *oidc.UserInfo) {
assert.Equal(t, User.GetUserId(), userinfo.Subject)
assert.Equal(t, "Mickey", userinfo.GivenName)
assert.Equal(t, "Mouse", userinfo.FamilyName)
assert.Equal(t, "Mickey Mouse", userinfo.Name)
assert.NotEmpty(t, userinfo.PreferredUsername)
assert.Equal(t, userinfo.PreferredUsername, userinfo.Email)
assert.False(t, bool(userinfo.EmailVerified))
assertOIDCTime(t, userinfo.UpdatedAt, User.GetDetails().GetChangeDate().AsTime())
}
func assertIntrospection(
t *testing.T,
introspection *oidc.IntrospectionResponse,

View File

@@ -28,7 +28,6 @@ func (s *Server) Introspect(ctx context.Context, r *op.Request[op.IntrospectionR
return s.LegacyServer.Introspect(ctx, r)
}
if features.TriggerIntrospectionProjections {
// Execute all triggers in one concurrent sweep.
query.TriggerIntrospectionProjections(ctx)
}
@@ -100,7 +99,7 @@ func (s *Server) Introspect(ctx context.Context, r *op.Request[op.IntrospectionR
if err = validateIntrospectionAudience(token.audience, client.clientID, client.projectID); err != nil {
return nil, err
}
userInfo, err := s.userInfo(ctx, token.userID, client.projectID, token.scope, []string{client.projectID})
userInfo, err := s.userInfo(ctx, token.userID, client.projectID, client.projectRoleAssertion, token.scope, []string{client.projectID})
if err != nil {
return nil, err
}
@@ -124,9 +123,10 @@ func (s *Server) Introspect(ctx context.Context, r *op.Request[op.IntrospectionR
}
type introspectionClientResult struct {
clientID string
projectID string
err error
clientID string
projectID string
projectRoleAssertion bool
err error
}
var errNoClientSecret = errors.New("client has no configured secret")
@@ -134,35 +134,36 @@ var errNoClientSecret = errors.New("client has no configured secret")
func (s *Server) introspectionClientAuth(ctx context.Context, cc *op.ClientCredentials, rc chan<- *introspectionClientResult) {
ctx, span := tracing.NewSpan(ctx)
clientID, projectID, err := func() (string, string, error) {
clientID, projectID, projectRoleAssertion, err := func() (string, string, bool, error) {
client, err := s.clientFromCredentials(ctx, cc)
if err != nil {
return "", "", err
return "", "", false, err
}
if cc.ClientAssertion != "" {
verifier := op.NewJWTProfileVerifierKeySet(keySetMap(client.PublicKeys), op.IssuerFromContext(ctx), time.Hour, time.Second)
if _, err := op.VerifyJWTAssertion(ctx, cc.ClientAssertion, verifier); err != nil {
return "", "", oidc.ErrUnauthorizedClient().WithParent(err)
return "", "", false, oidc.ErrUnauthorizedClient().WithParent(err)
}
return client.ClientID, client.ProjectID, nil
return client.ClientID, client.ProjectID, client.ProjectRoleAssertion, nil
}
if client.HashedSecret != "" {
if err := s.introspectionClientSecretAuth(ctx, client, cc.ClientSecret); err != nil {
return "", "", oidc.ErrUnauthorizedClient().WithParent(err)
return "", "", false, oidc.ErrUnauthorizedClient().WithParent(err)
}
return client.ClientID, client.ProjectID, nil
return client.ClientID, client.ProjectID, client.ProjectRoleAssertion, nil
}
return "", "", oidc.ErrUnauthorizedClient().WithParent(errNoClientSecret)
return "", "", false, oidc.ErrUnauthorizedClient().WithParent(errNoClientSecret)
}()
span.EndWithError(err)
rc <- &introspectionClientResult{
clientID: clientID,
projectID: projectID,
err: err,
clientID: clientID,
projectID: projectID,
projectRoleAssertion: projectRoleAssertion,
err: err,
}
}

View File

@@ -39,23 +39,23 @@ const (
func TestMain(m *testing.M) {
os.Exit(func() int {
ctx, errCtx, cancel := integration.Contexts(10 * time.Minute)
ctx, _, cancel := integration.Contexts(10 * time.Minute)
defer cancel()
Tester = integration.NewTester(ctx)
defer Tester.Done()
CTX, _ = Tester.WithAuthorization(ctx, integration.OrgOwner), errCtx
CTX = Tester.WithAuthorization(ctx, integration.OrgOwner)
User = Tester.CreateHumanUser(CTX)
Tester.SetUserPassword(CTX, User.GetUserId(), integration.UserPassword, false)
Tester.RegisterUserPasskey(CTX, User.GetUserId())
CTXLOGIN, _ = Tester.WithAuthorization(ctx, integration.Login), errCtx
CTXLOGIN = Tester.WithAuthorization(ctx, integration.Login)
return m.Run()
}())
}
func Test_ZITADEL_API_missing_audience_scope(t *testing.T) {
clientID := createClient(t)
clientID, _ := createClient(t)
authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID)
sessionID, sessionToken, startTime, changeTime := Tester.CreateVerifiedWebAuthNSession(t, CTXLOGIN, User.GetUserId())
linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
@@ -84,7 +84,7 @@ func Test_ZITADEL_API_missing_audience_scope(t *testing.T) {
}
func Test_ZITADEL_API_missing_authentication(t *testing.T) {
clientID := createClient(t)
clientID, _ := createClient(t)
authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, zitadelAudienceScope)
createResp, err := Tester.Client.SessionV2.CreateSession(CTX, &session.CreateSessionRequest{
Checks: &session.Checks{
@@ -118,7 +118,7 @@ func Test_ZITADEL_API_missing_authentication(t *testing.T) {
}
func Test_ZITADEL_API_missing_mfa(t *testing.T) {
clientID := createClient(t)
clientID, _ := createClient(t)
authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, zitadelAudienceScope)
sessionID, sessionToken, startTime, changeTime := Tester.CreatePasswordSession(t, CTX, User.GetUserId(), integration.UserPassword)
linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
@@ -146,7 +146,7 @@ func Test_ZITADEL_API_missing_mfa(t *testing.T) {
}
func Test_ZITADEL_API_success(t *testing.T) {
clientID := createClient(t)
clientID, _ := createClient(t)
authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, zitadelAudienceScope)
sessionID, sessionToken, startTime, changeTime := Tester.CreateVerifiedWebAuthNSession(t, CTXLOGIN, User.GetUserId())
linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
@@ -176,7 +176,7 @@ func Test_ZITADEL_API_success(t *testing.T) {
func Test_ZITADEL_API_glob_redirects(t *testing.T) {
const redirectURI = "https://my-org-1yfnjl2xj-my-app.vercel.app/api/auth/callback/zitadel"
clientID := createClientWithOpts(t, clientOpts{
clientID, _ := createClientWithOpts(t, clientOpts{
redirectURI: "https://my-org-*-my-app.vercel.app/api/auth/callback/zitadel",
logoutURI: "https://my-org-*-my-app.vercel.app/",
devMode: true,
@@ -209,7 +209,7 @@ func Test_ZITADEL_API_glob_redirects(t *testing.T) {
}
func Test_ZITADEL_API_inactive_access_token(t *testing.T) {
clientID := createClient(t)
clientID, _ := createClient(t)
authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, oidc.ScopeOfflineAccess, zitadelAudienceScope)
sessionID, sessionToken, startTime, changeTime := Tester.CreateVerifiedWebAuthNSession(t, CTXLOGIN, User.GetUserId())
linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
@@ -249,7 +249,7 @@ func Test_ZITADEL_API_inactive_access_token(t *testing.T) {
}
func Test_ZITADEL_API_terminated_session(t *testing.T) {
clientID := createClient(t)
clientID, _ := createClient(t)
provider, err := Tester.CreateRelyingParty(CTX, clientID, redirectURI)
require.NoError(t, err)
authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, oidc.ScopeOfflineAccess, zitadelAudienceScope)
@@ -291,7 +291,7 @@ func Test_ZITADEL_API_terminated_session(t *testing.T) {
}
func Test_ZITADEL_API_terminated_session_user_disabled(t *testing.T) {
clientID := createClient(t)
clientID, _ := createClient(t)
tests := []struct {
name string
disable func(userID string) error
@@ -361,7 +361,7 @@ func Test_ZITADEL_API_terminated_session_user_disabled(t *testing.T) {
}
}
func createClient(t testing.TB) string {
func createClient(t testing.TB) (clientID, projectID string) {
return createClientWithOpts(t, clientOpts{
redirectURI: redirectURI,
logoutURI: logoutRedirectURI,
@@ -375,12 +375,12 @@ type clientOpts struct {
devMode bool
}
func createClientWithOpts(t testing.TB, opts clientOpts) string {
func createClientWithOpts(t testing.TB, opts clientOpts) (clientID, projectID string) {
project, err := Tester.CreateProject(CTX)
require.NoError(t, err)
app, err := Tester.CreateOIDCNativeClient(CTX, opts.redirectURI, opts.logoutURI, project.GetId(), opts.devMode)
require.NoError(t, err)
return app.GetClientId()
return app.GetClientId(), project.GetId()
}
func createImplicitClient(t testing.TB) string {

View File

@@ -188,13 +188,6 @@ func (s *Server) DeviceToken(ctx context.Context, r *op.ClientRequest[oidc.Devic
return s.LegacyServer.DeviceToken(ctx, r)
}
func (s *Server) UserInfo(ctx context.Context, r *op.Request[oidc.UserInfoRequest]) (_ *op.Response, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
return s.LegacyServer.UserInfo(ctx, r)
}
func (s *Server) Revocation(ctx context.Context, r *op.ClientRequest[oidc.RevocationRequest]) (_ *op.Response, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()

View File

@@ -216,7 +216,7 @@ func (s *Server) createExchangeTokens(ctx context.Context, tokenType oidc.TokenT
)
if slices.Contains(scopes, oidc.ScopeOpenID) || tokenType == oidc.JWTTokenType || tokenType == oidc.IDTokenType {
projectID := client.client.ProjectID
userInfo, err = s.userInfo(ctx, subjectToken.userID, projectID, scopes, []string{projectID})
userInfo, err = s.userInfo(ctx, subjectToken.userID, projectID, client.client.ProjectRoleAssertion, scopes, []string{projectID})
if err != nil {
return nil, err
}

View File

@@ -5,21 +5,63 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"slices"
"strings"
"github.com/dop251/goja"
"github.com/zitadel/logging"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op"
"github.com/zitadel/zitadel/internal/actions"
"github.com/zitadel/zitadel/internal/actions/object"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
func (s *Server) userInfo(ctx context.Context, userID, projectID string, scope, roleAudience []string) (_ *oidc.UserInfo, err error) {
roleAudience, requestedRoles := prepareRoles(ctx, projectID, scope, roleAudience)
func (s *Server) UserInfo(ctx context.Context, r *op.Request[oidc.UserInfoRequest]) (_ *op.Response, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() {
err = oidcError(err)
span.EndWithError(err)
}()
features := authz.GetFeatures(ctx)
if features.LegacyIntrospection {
return s.LegacyServer.UserInfo(ctx, r)
}
if features.TriggerIntrospectionProjections {
query.TriggerOIDCUserInfoProjections(ctx)
}
token, err := s.verifyAccessToken(ctx, r.Data.AccessToken)
if err != nil {
return nil, op.NewStatusError(oidc.ErrAccessDenied().WithDescription("access token invalid").WithParent(err), http.StatusUnauthorized)
}
var (
projectID string
assertion bool
)
if token.clientID != "" {
projectID, assertion, err = s.query.GetOIDCUserinfoClientByID(ctx, token.clientID)
if err != nil {
return nil, err
}
}
userInfo, err := s.userInfo(ctx, token.userID, projectID, assertion, token.scope, nil)
if err != nil {
return nil, err
}
return op.NewResponse(userInfo), nil
}
func (s *Server) userInfo(ctx context.Context, userID, projectID string, projectRoleAssertion bool, scope, roleAudience []string) (_ *oidc.UserInfo, err error) {
roleAudience, requestedRoles := prepareRoles(ctx, projectID, projectRoleAssertion, scope, roleAudience)
qu, err := s.query.GetOIDCUserInfo(ctx, userID, roleAudience)
if err != nil {
return nil, err
@@ -31,15 +73,14 @@ func (s *Server) userInfo(ctx context.Context, userID, projectID string, scope,
// prepareRoles scans the requested scopes, appends to roleAudience and returns the requestedRoles.
//
// Scopes with [ScopeProjectRolePrefix] are added to requestedRoles.
// When [ScopeProjectsRoles] is present and roleAudience was empty,
// project IDs with the [domain.ProjectIDScope] prefix are added to the roleAudience.
//
// Scopes with [ScopeProjectRolePrefix] are added to requestedRoles.
//
// If the resulting requestedRoles or roleAudience are not not empty,
// If projectRoleAssertion is true and the resulting requestedRoles or roleAudience are not empty,
// the current projectID will always be parts or roleAudience.
// Else nil, nil is returned.
func prepareRoles(ctx context.Context, projectID string, scope, roleAudience []string) (ra, requestedRoles []string) {
func prepareRoles(ctx context.Context, projectID string, projectRoleAssertion bool, scope, roleAudience []string) (ra, requestedRoles []string) {
// if all roles are requested take the audience for those from the scopes
if slices.Contains(scope, ScopeProjectsRoles) && len(roleAudience) == 0 {
roleAudience = domain.AddAudScopeToAudience(ctx, roleAudience, scope)
@@ -50,7 +91,7 @@ func prepareRoles(ctx context.Context, projectID string, scope, roleAudience []s
requestedRoles = append(requestedRoles, role)
}
}
if len(requestedRoles) == 0 && len(roleAudience) == 0 {
if !projectRoleAssertion && len(requestedRoles) == 0 && len(roleAudience) == 0 {
return nil, nil
}

View File

@@ -0,0 +1,270 @@
//go:build integration
package oidc_test
import (
"fmt"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v3/pkg/client/rp"
"github.com/zitadel/oidc/v3/pkg/oidc"
"golang.org/x/oauth2"
oidc_api "github.com/zitadel/zitadel/internal/api/oidc"
"github.com/zitadel/zitadel/internal/integration"
feature "github.com/zitadel/zitadel/pkg/grpc/feature/v2beta"
"github.com/zitadel/zitadel/pkg/grpc/management"
oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2beta"
)
// TestServer_UserInfo is a top-level test which re-executes the actual
// userinfo integration test against a matrix of different feature flags.
// This ensure that the response of the different implementations remains the same.
func TestServer_UserInfo(t *testing.T) {
iamOwnerCTX := Tester.WithAuthorization(CTX, integration.IAMOwner)
t.Cleanup(func() {
_, err := Tester.Client.FeatureV2.ResetInstanceFeatures(iamOwnerCTX, &feature.ResetInstanceFeaturesRequest{})
require.NoError(t, err)
})
tests := []struct {
name string
legacy bool
trigger bool
}{
{
name: "legacy enabled",
legacy: true,
},
{
name: "legacy disabled, trigger disabled",
legacy: false,
trigger: false,
},
{
name: "legacy disabled, trigger enabled",
legacy: false,
trigger: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := Tester.Client.FeatureV2.SetInstanceFeatures(iamOwnerCTX, &feature.SetInstanceFeaturesRequest{
OidcLegacyIntrospection: &tt.legacy,
OidcTriggerIntrospectionProjections: &tt.trigger,
})
require.NoError(t, err)
testServer_UserInfo(t)
})
}
}
// testServer_UserInfo is the actual userinfo integration test,
// which calls the userinfo endpoint with different client configurations, roles and token scopes.
func testServer_UserInfo(t *testing.T) {
const role = "testUserRole"
clientID, projectID := createClient(t)
_, err := Tester.Client.Mgmt.AddProjectRole(CTX, &management.AddProjectRoleRequest{
ProjectId: projectID,
RoleKey: role,
DisplayName: "test",
})
require.NoError(t, err)
_, err = Tester.Client.Mgmt.AddUserGrant(CTX, &management.AddUserGrantRequest{
UserId: User.GetUserId(),
ProjectId: projectID,
RoleKeys: []string{role},
})
require.NoError(t, err)
tests := []struct {
name string
prepare func(t *testing.T, clientID string, scope []string) *oidc.Tokens[*oidc.IDTokenClaims]
scope []string
assertions []func(*testing.T, *oidc.UserInfo)
wantErr bool
}{
{
name: "invalid token",
prepare: func(*testing.T, string, []string) *oidc.Tokens[*oidc.IDTokenClaims] {
return &oidc.Tokens[*oidc.IDTokenClaims]{
Token: &oauth2.Token{
AccessToken: "DEAFBEEFDEADBEEF",
TokenType: oidc.BearerToken,
},
IDTokenClaims: &oidc.IDTokenClaims{
TokenClaims: oidc.TokenClaims{
Subject: User.GetUserId(),
},
},
}
},
scope: []string{oidc.ScopeProfile, oidc.ScopeOpenID, oidc.ScopeEmail, oidc.ScopeOfflineAccess},
assertions: []func(*testing.T, *oidc.UserInfo){
func(t *testing.T, ui *oidc.UserInfo) {
assert.Nil(t, ui)
},
},
wantErr: true,
},
{
name: "standard scopes",
prepare: getTokens,
scope: []string{oidc.ScopeProfile, oidc.ScopeOpenID, oidc.ScopeEmail, oidc.ScopeOfflineAccess},
assertions: []func(*testing.T, *oidc.UserInfo){
assertUserinfo,
func(t *testing.T, ui *oidc.UserInfo) {
assertNoReservedScopes(t, ui.Claims)
},
},
},
{
name: "project role assertion",
prepare: func(t *testing.T, clientID string, scope []string) *oidc.Tokens[*oidc.IDTokenClaims] {
_, err := Tester.Client.Mgmt.UpdateProject(CTX, &management.UpdateProjectRequest{
Id: projectID,
Name: fmt.Sprintf("project-%d", time.Now().UnixNano()),
ProjectRoleAssertion: true,
})
require.NoError(t, err)
t.Cleanup(func() {
_, err := Tester.Client.Mgmt.UpdateProject(CTX, &management.UpdateProjectRequest{
Id: projectID,
Name: fmt.Sprintf("project-%d", time.Now().UnixNano()),
ProjectRoleAssertion: false,
})
require.NoError(t, err)
})
resp, err := Tester.Client.Mgmt.GetProjectByID(CTX, &management.GetProjectByIDRequest{Id: projectID})
require.NoError(t, err)
require.True(t, resp.GetProject().GetProjectRoleAssertion(), "project role assertion")
return getTokens(t, clientID, scope)
},
scope: []string{oidc.ScopeProfile, oidc.ScopeOpenID, oidc.ScopeEmail, oidc.ScopeOfflineAccess},
assertions: []func(*testing.T, *oidc.UserInfo){
assertUserinfo,
func(t *testing.T, ui *oidc.UserInfo) {
assertProjectRoleClaims(t, projectID, ui.Claims, role)
},
},
},
{
name: "projects roles scope",
prepare: getTokens,
scope: []string{oidc.ScopeProfile, oidc.ScopeOpenID, oidc.ScopeEmail, oidc.ScopeOfflineAccess, oidc_api.ScopeProjectRolePrefix + role},
assertions: []func(*testing.T, *oidc.UserInfo){
assertUserinfo,
func(t *testing.T, ui *oidc.UserInfo) {
assertProjectRoleClaims(t, projectID, ui.Claims, role)
},
},
},
{
name: "PAT",
prepare: func(t *testing.T, clientID string, scope []string) *oidc.Tokens[*oidc.IDTokenClaims] {
user := Tester.Users.Get(integration.FirstInstanceUsersKey, integration.OrgOwner)
return &oidc.Tokens[*oidc.IDTokenClaims]{
Token: &oauth2.Token{
AccessToken: user.Token,
TokenType: oidc.BearerToken,
},
IDTokenClaims: &oidc.IDTokenClaims{
TokenClaims: oidc.TokenClaims{
Subject: user.ID,
},
},
}
},
assertions: []func(*testing.T, *oidc.UserInfo){
func(t *testing.T, ui *oidc.UserInfo) {
user := Tester.Users.Get(integration.FirstInstanceUsersKey, integration.OrgOwner)
assert.Equal(t, user.ID, ui.Subject)
assert.Equal(t, user.PreferredLoginName, ui.PreferredUsername)
assert.Equal(t, user.Machine.Name, ui.Name)
assert.Equal(t, user.ResourceOwner, ui.Claims[oidc_api.ClaimResourceOwnerID])
assert.NotEmpty(t, ui.Claims[oidc_api.ClaimResourceOwnerName])
assert.NotEmpty(t, ui.Claims[oidc_api.ClaimResourceOwnerPrimaryDomain])
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tokens := tt.prepare(t, clientID, tt.scope)
provider, err := Tester.CreateRelyingParty(CTX, clientID, redirectURI)
require.NoError(t, err)
userinfo, err := rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.Subject, provider)
if tt.wantErr {
assert.Error(t, err)
return
}
require.NoError(t, err)
for _, assertion := range tt.assertions {
assertion(t, userinfo)
}
})
}
}
func getTokens(t *testing.T, clientID string, scope []string) *oidc.Tokens[*oidc.IDTokenClaims] {
authRequestID := createAuthRequest(t, clientID, redirectURI, scope...)
sessionID, sessionToken, startTime, changeTime := Tester.CreateVerifiedWebAuthNSession(t, CTXLOGIN, User.GetUserId())
linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
AuthRequestId: authRequestID,
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
Session: &oidc_pb.Session{
SessionId: sessionID,
SessionToken: sessionToken,
},
},
})
require.NoError(t, err)
// code exchange
code := assertCodeResponse(t, linkResp.GetCallbackUrl())
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err)
assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
return tokens
}
func assertUserinfo(t *testing.T, userinfo *oidc.UserInfo) {
t.Helper()
assert.Equal(t, User.GetUserId(), userinfo.Subject)
assert.Equal(t, "Mickey", userinfo.GivenName)
assert.Equal(t, "Mouse", userinfo.FamilyName)
assert.Equal(t, "Mickey Mouse", userinfo.Name)
assert.NotEmpty(t, userinfo.PreferredUsername)
assert.Equal(t, userinfo.PreferredUsername, userinfo.Email)
assert.False(t, bool(userinfo.EmailVerified))
assertOIDCTime(t, userinfo.UpdatedAt, User.GetDetails().GetChangeDate().AsTime())
}
func assertNoReservedScopes(t *testing.T, claims map[string]any) {
t.Helper()
t.Log(claims)
for claim := range claims {
assert.Falsef(t, strings.HasPrefix(claim, oidc_api.ClaimPrefix), "claim %s has prefix %s", claim, oidc_api.ClaimPrefix)
}
}
func assertProjectRoleClaims(t *testing.T, projectID string, claims map[string]any, roles ...string) {
t.Helper()
projectIDRoleClaim := fmt.Sprintf(oidc_api.ClaimProjectRolesFormat, projectID)
for _, claim := range []string{oidc_api.ClaimProjectRoles, projectIDRoleClaim} {
roleMap, ok := claims[claim].(map[string]any)
require.Truef(t, ok, "claim %s not found or wrong type %T", claim, claims[claim])
for _, roleKey := range roles {
role, ok := roleMap[roleKey].(map[string]any)
require.Truef(t, ok, "role %s not found or wrong type %T", roleKey, roleMap[roleKey])
assert.Equal(t, role[Tester.Organisation.ID], Tester.Organisation.Domain, "org domain in role")
}
}
}

View File

@@ -17,9 +17,10 @@ import (
func Test_prepareRoles(t *testing.T) {
type args struct {
projectID string
scope []string
roleAudience []string
projectID string
projectRoleAssertion bool
scope []string
roleAudience []string
}
tests := []struct {
name string
@@ -30,19 +31,32 @@ func Test_prepareRoles(t *testing.T) {
{
name: "empty scope and roleAudience",
args: args{
projectID: "projID",
scope: nil,
roleAudience: nil,
projectID: "projID",
projectRoleAssertion: false,
scope: nil,
roleAudience: nil,
},
wantRa: nil,
wantRequestedRoles: nil,
},
{
name: "project role assertion",
args: args{
projectID: "projID",
projectRoleAssertion: true,
scope: nil,
roleAudience: nil,
},
wantRa: []string{"projID"},
wantRequestedRoles: []string{},
},
{
name: "some scope and roleAudience",
args: args{
projectID: "projID",
scope: []string{"openid", "profile"},
roleAudience: []string{"project2"},
projectID: "projID",
projectRoleAssertion: false,
scope: []string{"openid", "profile"},
roleAudience: []string{"project2"},
},
wantRa: []string{"project2", "projID"},
wantRequestedRoles: []string{},
@@ -50,9 +64,10 @@ func Test_prepareRoles(t *testing.T) {
{
name: "scope projects roles",
args: args{
projectID: "projID",
scope: []string{ScopeProjectsRoles, domain.ProjectIDScope + "project2" + domain.AudSuffix},
roleAudience: nil,
projectID: "projID",
projectRoleAssertion: false,
scope: []string{ScopeProjectsRoles, domain.ProjectIDScope + "project2" + domain.AudSuffix},
roleAudience: nil,
},
wantRa: []string{"project2", "projID"},
wantRequestedRoles: []string{},
@@ -60,9 +75,10 @@ func Test_prepareRoles(t *testing.T) {
{
name: "scope project role prefix",
args: args{
projectID: "projID",
scope: []string{"openid", "profile", ScopeProjectRolePrefix + "foo", ScopeProjectRolePrefix + "bar"},
roleAudience: nil,
projectID: "projID",
projectRoleAssertion: false,
scope: []string{"openid", "profile", ScopeProjectRolePrefix + "foo", ScopeProjectRolePrefix + "bar"},
roleAudience: nil,
},
wantRa: []string{"projID"},
wantRequestedRoles: []string{"foo", "bar"},
@@ -70,7 +86,7 @@ func Test_prepareRoles(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotRa, gotRequestedRoles := prepareRoles(context.Background(), tt.args.projectID, tt.args.scope, tt.args.roleAudience)
gotRa, gotRequestedRoles := prepareRoles(context.Background(), tt.args.projectID, tt.args.projectRoleAssertion, tt.args.scope, tt.args.roleAudience)
assert.Equal(t, tt.wantRa, gotRa, "roleAudience")
assert.Equal(t, tt.wantRequestedRoles, gotRequestedRoles, "requestedRoles")
})