From 6a51c4b0f5af14ba80b91b63eafd017650fbff7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20M=C3=B6hlmann?= Date: Tue, 9 Apr 2024 16:15:35 +0300 Subject: [PATCH] feat(oidc): optimize the userinfo endpoint (#7706) * feat(oidc): optimize the userinfo endpoint * store project ID in the access token * query for projectID if not in token * add scope based tests * Revert "store project ID in the access token" This reverts commit 5f0262f23988e7f62d415d0e4a02a705f6ad5197. * query project role assertion * use project role assertion setting to return roles * workaround eventual consistency and handle PAT * do not append empty project id --- .../api/oidc/auth_request_integration_test.go | 26 +- internal/api/oidc/client_integration_test.go | 41 --- internal/api/oidc/introspect.go | 31 +- internal/api/oidc/oidc_integration_test.go | 28 +- internal/api/oidc/server.go | 7 - internal/api/oidc/token_exchange.go | 2 +- internal/api/oidc/userinfo.go | 55 +++- .../api/oidc/userinfo_integration_test.go | 270 ++++++++++++++++++ internal/api/oidc/userinfo_test.go | 48 ++-- internal/query/auth_request.go | 2 +- .../query/{embed => }/auth_request_by_id.sql | 0 internal/query/introspection.go | 20 +- .../introspection_client_by_id.sql | 9 +- internal/query/introspection_test.go | 38 +-- internal/query/oidc_client.go | 3 +- .../query/{embed => }/oidc_client_by_id.sql | 12 +- internal/query/oidc_client_test.go | 4 + internal/query/testdata/oidc_client_jwt.json | 3 +- .../testdata/oidc_client_no_settings.json | 3 +- .../query/testdata/oidc_client_public.json | 3 +- .../query/testdata/oidc_client_secret.json | 3 +- internal/query/userinfo.go | 26 +- internal/query/{embed => }/userinfo_by_id.sql | 0 internal/query/userinfo_client_by_id.sql | 6 + internal/query/userinfo_test.go | 47 +++ 25 files changed, 528 insertions(+), 159 deletions(-) create mode 100644 internal/api/oidc/userinfo_integration_test.go rename internal/query/{embed => }/auth_request_by_id.sql (100%) rename internal/query/{embed => }/introspection_client_by_id.sql (61%) rename internal/query/{embed => }/oidc_client_by_id.sql (84%) rename internal/query/{embed => }/userinfo_by_id.sql (100%) create mode 100644 internal/query/userinfo_client_by_id.sql diff --git a/internal/api/oidc/auth_request_integration_test.go b/internal/api/oidc/auth_request_integration_test.go index 75306fbfb7..c36e06c6aa 100644 --- a/internal/api/oidc/auth_request_integration_test.go +++ b/internal/api/oidc/auth_request_integration_test.go @@ -28,14 +28,14 @@ var ( ) func TestOPStorage_CreateAuthRequest(t *testing.T) { - clientID := createClient(t) + clientID, _ := createClient(t) id := createAuthRequest(t, clientID, redirectURI) require.Contains(t, id, command.IDPrefixV2) } func TestOPStorage_CreateAccessToken_code(t *testing.T) { - clientID := createClient(t) + clientID, _ := createClient(t) authRequestID := createAuthRequest(t, clientID, redirectURI) sessionID, sessionToken, startTime, changeTime := Tester.CreateVerifiedWebAuthNSession(t, CTXLOGIN, User.GetUserId()) linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{ @@ -124,7 +124,7 @@ func TestOPStorage_CreateAccessToken_implicit(t *testing.T) { } func TestOPStorage_CreateAccessAndRefreshTokens_code(t *testing.T) { - clientID := createClient(t) + clientID, _ := createClient(t) authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, oidc.ScopeOfflineAccess) sessionID, sessionToken, startTime, changeTime := Tester.CreateVerifiedWebAuthNSession(t, CTXLOGIN, User.GetUserId()) linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{ @@ -147,7 +147,7 @@ func TestOPStorage_CreateAccessAndRefreshTokens_code(t *testing.T) { } func TestOPStorage_CreateAccessAndRefreshTokens_refresh(t *testing.T) { - clientID := createClient(t) + clientID, _ := createClient(t) provider, err := Tester.CreateRelyingParty(CTX, clientID, redirectURI) require.NoError(t, err) authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, oidc.ScopeOfflineAccess) @@ -183,7 +183,7 @@ func TestOPStorage_CreateAccessAndRefreshTokens_refresh(t *testing.T) { } func TestOPStorage_RevokeToken_access_token(t *testing.T) { - clientID := createClient(t) + clientID, _ := createClient(t) provider, err := Tester.CreateRelyingParty(CTX, clientID, redirectURI) require.NoError(t, err) authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, oidc.ScopeOfflineAccess) @@ -226,7 +226,7 @@ func TestOPStorage_RevokeToken_access_token(t *testing.T) { } func TestOPStorage_RevokeToken_access_token_invalid_token_hint_type(t *testing.T) { - clientID := createClient(t) + clientID, _ := createClient(t) provider, err := Tester.CreateRelyingParty(CTX, clientID, redirectURI) require.NoError(t, err) authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, oidc.ScopeOfflineAccess) @@ -263,7 +263,7 @@ func TestOPStorage_RevokeToken_access_token_invalid_token_hint_type(t *testing.T } func TestOPStorage_RevokeToken_refresh_token(t *testing.T) { - clientID := createClient(t) + clientID, _ := createClient(t) provider, err := Tester.CreateRelyingParty(CTX, clientID, redirectURI) require.NoError(t, err) authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, oidc.ScopeOfflineAccess) @@ -306,7 +306,7 @@ func TestOPStorage_RevokeToken_refresh_token(t *testing.T) { } func TestOPStorage_RevokeToken_refresh_token_invalid_token_type_hint(t *testing.T) { - clientID := createClient(t) + clientID, _ := createClient(t) provider, err := Tester.CreateRelyingParty(CTX, clientID, redirectURI) require.NoError(t, err) authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, oidc.ScopeOfflineAccess) @@ -343,7 +343,7 @@ func TestOPStorage_RevokeToken_refresh_token_invalid_token_type_hint(t *testing. } func TestOPStorage_RevokeToken_invalid_client(t *testing.T) { - clientID := createClient(t) + clientID, _ := createClient(t) authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, oidc.ScopeOfflineAccess) sessionID, sessionToken, startTime, changeTime := Tester.CreateVerifiedWebAuthNSession(t, CTXLOGIN, User.GetUserId()) linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{ @@ -365,7 +365,7 @@ func TestOPStorage_RevokeToken_invalid_client(t *testing.T) { assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) // simulate second client (not part of the audience) trying to revoke the token - otherClientID := createClient(t) + otherClientID, _ := createClient(t) provider, err := Tester.CreateRelyingParty(CTX, otherClientID, redirectURI) require.NoError(t, err) err = rp.RevokeToken(CTX, provider, tokens.AccessToken, "") @@ -373,7 +373,7 @@ func TestOPStorage_RevokeToken_invalid_client(t *testing.T) { } func TestOPStorage_TerminateSession(t *testing.T) { - clientID := createClient(t) + clientID, _ := createClient(t) provider, err := Tester.CreateRelyingParty(CTX, clientID, redirectURI) require.NoError(t, err) authRequestID := createAuthRequest(t, clientID, redirectURI) @@ -410,7 +410,7 @@ func TestOPStorage_TerminateSession(t *testing.T) { } func TestOPStorage_TerminateSession_refresh_grant(t *testing.T) { - clientID := createClient(t) + clientID, _ := createClient(t) provider, err := Tester.CreateRelyingParty(CTX, clientID, redirectURI) require.NoError(t, err) authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, oidc.ScopeOfflineAccess) @@ -454,7 +454,7 @@ func TestOPStorage_TerminateSession_refresh_grant(t *testing.T) { } func TestOPStorage_TerminateSession_empty_id_token_hint(t *testing.T) { - clientID := createClient(t) + clientID, _ := createClient(t) provider, err := Tester.CreateRelyingParty(CTX, clientID, redirectURI) require.NoError(t, err) authRequestID := createAuthRequest(t, clientID, redirectURI) diff --git a/internal/api/oidc/client_integration_test.go b/internal/api/oidc/client_integration_test.go index 96812120d8..c7ace3c097 100644 --- a/internal/api/oidc/client_integration_test.go +++ b/internal/api/oidc/client_integration_test.go @@ -25,36 +25,6 @@ import ( "github.com/zitadel/zitadel/pkg/grpc/user" ) -func TestOPStorage_SetUserinfoFromToken(t *testing.T) { - clientID := createClient(t) - authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeOfflineAccess) - sessionID, sessionToken, startTime, changeTime := Tester.CreateVerifiedWebAuthNSession(t, CTXLOGIN, User.GetUserId()) - linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{ - AuthRequestId: authRequestID, - CallbackKind: &oidc_pb.CreateCallbackRequest_Session{ - Session: &oidc_pb.Session{ - SessionId: sessionID, - SessionToken: sessionToken, - }, - }, - }) - require.NoError(t, err) - - // code exchange - code := assertCodeResponse(t, linkResp.GetCallbackUrl()) - tokens, err := exchangeTokens(t, clientID, code, redirectURI) - require.NoError(t, err) - assertTokens(t, tokens, true) - assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) - - // test actual userinfo - provider, err := Tester.CreateRelyingParty(CTX, clientID, redirectURI) - require.NoError(t, err) - userinfo, err := rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.Subject, provider) - require.NoError(t, err) - assertUserinfo(t, userinfo) -} - func TestServer_Introspect(t *testing.T) { project, err := Tester.CreateProject(CTX) require.NoError(t, err) @@ -172,17 +142,6 @@ func TestServer_Introspect(t *testing.T) { } } -func assertUserinfo(t *testing.T, userinfo *oidc.UserInfo) { - assert.Equal(t, User.GetUserId(), userinfo.Subject) - assert.Equal(t, "Mickey", userinfo.GivenName) - assert.Equal(t, "Mouse", userinfo.FamilyName) - assert.Equal(t, "Mickey Mouse", userinfo.Name) - assert.NotEmpty(t, userinfo.PreferredUsername) - assert.Equal(t, userinfo.PreferredUsername, userinfo.Email) - assert.False(t, bool(userinfo.EmailVerified)) - assertOIDCTime(t, userinfo.UpdatedAt, User.GetDetails().GetChangeDate().AsTime()) -} - func assertIntrospection( t *testing.T, introspection *oidc.IntrospectionResponse, diff --git a/internal/api/oidc/introspect.go b/internal/api/oidc/introspect.go index ccc061f8e5..0615193b03 100644 --- a/internal/api/oidc/introspect.go +++ b/internal/api/oidc/introspect.go @@ -28,7 +28,6 @@ func (s *Server) Introspect(ctx context.Context, r *op.Request[op.IntrospectionR return s.LegacyServer.Introspect(ctx, r) } if features.TriggerIntrospectionProjections { - // Execute all triggers in one concurrent sweep. query.TriggerIntrospectionProjections(ctx) } @@ -100,7 +99,7 @@ func (s *Server) Introspect(ctx context.Context, r *op.Request[op.IntrospectionR if err = validateIntrospectionAudience(token.audience, client.clientID, client.projectID); err != nil { return nil, err } - userInfo, err := s.userInfo(ctx, token.userID, client.projectID, token.scope, []string{client.projectID}) + userInfo, err := s.userInfo(ctx, token.userID, client.projectID, client.projectRoleAssertion, token.scope, []string{client.projectID}) if err != nil { return nil, err } @@ -124,9 +123,10 @@ func (s *Server) Introspect(ctx context.Context, r *op.Request[op.IntrospectionR } type introspectionClientResult struct { - clientID string - projectID string - err error + clientID string + projectID string + projectRoleAssertion bool + err error } var errNoClientSecret = errors.New("client has no configured secret") @@ -134,35 +134,36 @@ var errNoClientSecret = errors.New("client has no configured secret") func (s *Server) introspectionClientAuth(ctx context.Context, cc *op.ClientCredentials, rc chan<- *introspectionClientResult) { ctx, span := tracing.NewSpan(ctx) - clientID, projectID, err := func() (string, string, error) { + clientID, projectID, projectRoleAssertion, err := func() (string, string, bool, error) { client, err := s.clientFromCredentials(ctx, cc) if err != nil { - return "", "", err + return "", "", false, err } if cc.ClientAssertion != "" { verifier := op.NewJWTProfileVerifierKeySet(keySetMap(client.PublicKeys), op.IssuerFromContext(ctx), time.Hour, time.Second) if _, err := op.VerifyJWTAssertion(ctx, cc.ClientAssertion, verifier); err != nil { - return "", "", oidc.ErrUnauthorizedClient().WithParent(err) + return "", "", false, oidc.ErrUnauthorizedClient().WithParent(err) } - return client.ClientID, client.ProjectID, nil + return client.ClientID, client.ProjectID, client.ProjectRoleAssertion, nil } if client.HashedSecret != "" { if err := s.introspectionClientSecretAuth(ctx, client, cc.ClientSecret); err != nil { - return "", "", oidc.ErrUnauthorizedClient().WithParent(err) + return "", "", false, oidc.ErrUnauthorizedClient().WithParent(err) } - return client.ClientID, client.ProjectID, nil + return client.ClientID, client.ProjectID, client.ProjectRoleAssertion, nil } - return "", "", oidc.ErrUnauthorizedClient().WithParent(errNoClientSecret) + return "", "", false, oidc.ErrUnauthorizedClient().WithParent(errNoClientSecret) }() span.EndWithError(err) rc <- &introspectionClientResult{ - clientID: clientID, - projectID: projectID, - err: err, + clientID: clientID, + projectID: projectID, + projectRoleAssertion: projectRoleAssertion, + err: err, } } diff --git a/internal/api/oidc/oidc_integration_test.go b/internal/api/oidc/oidc_integration_test.go index 1e6f8ed118..09e76391bd 100644 --- a/internal/api/oidc/oidc_integration_test.go +++ b/internal/api/oidc/oidc_integration_test.go @@ -39,23 +39,23 @@ const ( func TestMain(m *testing.M) { os.Exit(func() int { - ctx, errCtx, cancel := integration.Contexts(10 * time.Minute) + ctx, _, cancel := integration.Contexts(10 * time.Minute) defer cancel() Tester = integration.NewTester(ctx) defer Tester.Done() - CTX, _ = Tester.WithAuthorization(ctx, integration.OrgOwner), errCtx + CTX = Tester.WithAuthorization(ctx, integration.OrgOwner) User = Tester.CreateHumanUser(CTX) Tester.SetUserPassword(CTX, User.GetUserId(), integration.UserPassword, false) Tester.RegisterUserPasskey(CTX, User.GetUserId()) - CTXLOGIN, _ = Tester.WithAuthorization(ctx, integration.Login), errCtx + CTXLOGIN = Tester.WithAuthorization(ctx, integration.Login) return m.Run() }()) } func Test_ZITADEL_API_missing_audience_scope(t *testing.T) { - clientID := createClient(t) + clientID, _ := createClient(t) authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID) sessionID, sessionToken, startTime, changeTime := Tester.CreateVerifiedWebAuthNSession(t, CTXLOGIN, User.GetUserId()) linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{ @@ -84,7 +84,7 @@ func Test_ZITADEL_API_missing_audience_scope(t *testing.T) { } func Test_ZITADEL_API_missing_authentication(t *testing.T) { - clientID := createClient(t) + clientID, _ := createClient(t) authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, zitadelAudienceScope) createResp, err := Tester.Client.SessionV2.CreateSession(CTX, &session.CreateSessionRequest{ Checks: &session.Checks{ @@ -118,7 +118,7 @@ func Test_ZITADEL_API_missing_authentication(t *testing.T) { } func Test_ZITADEL_API_missing_mfa(t *testing.T) { - clientID := createClient(t) + clientID, _ := createClient(t) authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, zitadelAudienceScope) sessionID, sessionToken, startTime, changeTime := Tester.CreatePasswordSession(t, CTX, User.GetUserId(), integration.UserPassword) linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{ @@ -146,7 +146,7 @@ func Test_ZITADEL_API_missing_mfa(t *testing.T) { } func Test_ZITADEL_API_success(t *testing.T) { - clientID := createClient(t) + clientID, _ := createClient(t) authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, zitadelAudienceScope) sessionID, sessionToken, startTime, changeTime := Tester.CreateVerifiedWebAuthNSession(t, CTXLOGIN, User.GetUserId()) linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{ @@ -176,7 +176,7 @@ func Test_ZITADEL_API_success(t *testing.T) { func Test_ZITADEL_API_glob_redirects(t *testing.T) { const redirectURI = "https://my-org-1yfnjl2xj-my-app.vercel.app/api/auth/callback/zitadel" - clientID := createClientWithOpts(t, clientOpts{ + clientID, _ := createClientWithOpts(t, clientOpts{ redirectURI: "https://my-org-*-my-app.vercel.app/api/auth/callback/zitadel", logoutURI: "https://my-org-*-my-app.vercel.app/", devMode: true, @@ -209,7 +209,7 @@ func Test_ZITADEL_API_glob_redirects(t *testing.T) { } func Test_ZITADEL_API_inactive_access_token(t *testing.T) { - clientID := createClient(t) + clientID, _ := createClient(t) authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, oidc.ScopeOfflineAccess, zitadelAudienceScope) sessionID, sessionToken, startTime, changeTime := Tester.CreateVerifiedWebAuthNSession(t, CTXLOGIN, User.GetUserId()) linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{ @@ -249,7 +249,7 @@ func Test_ZITADEL_API_inactive_access_token(t *testing.T) { } func Test_ZITADEL_API_terminated_session(t *testing.T) { - clientID := createClient(t) + clientID, _ := createClient(t) provider, err := Tester.CreateRelyingParty(CTX, clientID, redirectURI) require.NoError(t, err) authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, oidc.ScopeOfflineAccess, zitadelAudienceScope) @@ -291,7 +291,7 @@ func Test_ZITADEL_API_terminated_session(t *testing.T) { } func Test_ZITADEL_API_terminated_session_user_disabled(t *testing.T) { - clientID := createClient(t) + clientID, _ := createClient(t) tests := []struct { name string disable func(userID string) error @@ -361,7 +361,7 @@ func Test_ZITADEL_API_terminated_session_user_disabled(t *testing.T) { } } -func createClient(t testing.TB) string { +func createClient(t testing.TB) (clientID, projectID string) { return createClientWithOpts(t, clientOpts{ redirectURI: redirectURI, logoutURI: logoutRedirectURI, @@ -375,12 +375,12 @@ type clientOpts struct { devMode bool } -func createClientWithOpts(t testing.TB, opts clientOpts) string { +func createClientWithOpts(t testing.TB, opts clientOpts) (clientID, projectID string) { project, err := Tester.CreateProject(CTX) require.NoError(t, err) app, err := Tester.CreateOIDCNativeClient(CTX, opts.redirectURI, opts.logoutURI, project.GetId(), opts.devMode) require.NoError(t, err) - return app.GetClientId() + return app.GetClientId(), project.GetId() } func createImplicitClient(t testing.TB) string { diff --git a/internal/api/oidc/server.go b/internal/api/oidc/server.go index 9b051c7785..e164740539 100644 --- a/internal/api/oidc/server.go +++ b/internal/api/oidc/server.go @@ -188,13 +188,6 @@ func (s *Server) DeviceToken(ctx context.Context, r *op.ClientRequest[oidc.Devic return s.LegacyServer.DeviceToken(ctx, r) } -func (s *Server) UserInfo(ctx context.Context, r *op.Request[oidc.UserInfoRequest]) (_ *op.Response, err error) { - ctx, span := tracing.NewSpan(ctx) - defer func() { span.EndWithError(err) }() - - return s.LegacyServer.UserInfo(ctx, r) -} - func (s *Server) Revocation(ctx context.Context, r *op.ClientRequest[oidc.RevocationRequest]) (_ *op.Response, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() diff --git a/internal/api/oidc/token_exchange.go b/internal/api/oidc/token_exchange.go index 77c9b46fa2..c50cf5859d 100644 --- a/internal/api/oidc/token_exchange.go +++ b/internal/api/oidc/token_exchange.go @@ -216,7 +216,7 @@ func (s *Server) createExchangeTokens(ctx context.Context, tokenType oidc.TokenT ) if slices.Contains(scopes, oidc.ScopeOpenID) || tokenType == oidc.JWTTokenType || tokenType == oidc.IDTokenType { projectID := client.client.ProjectID - userInfo, err = s.userInfo(ctx, subjectToken.userID, projectID, scopes, []string{projectID}) + userInfo, err = s.userInfo(ctx, subjectToken.userID, projectID, client.client.ProjectRoleAssertion, scopes, []string{projectID}) if err != nil { return nil, err } diff --git a/internal/api/oidc/userinfo.go b/internal/api/oidc/userinfo.go index a960c6ceca..90a77fa202 100644 --- a/internal/api/oidc/userinfo.go +++ b/internal/api/oidc/userinfo.go @@ -5,21 +5,63 @@ import ( "encoding/base64" "encoding/json" "fmt" + "net/http" "slices" "strings" "github.com/dop251/goja" "github.com/zitadel/logging" "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" "github.com/zitadel/zitadel/internal/actions" "github.com/zitadel/zitadel/internal/actions/object" + "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/query" + "github.com/zitadel/zitadel/internal/telemetry/tracing" ) -func (s *Server) userInfo(ctx context.Context, userID, projectID string, scope, roleAudience []string) (_ *oidc.UserInfo, err error) { - roleAudience, requestedRoles := prepareRoles(ctx, projectID, scope, roleAudience) +func (s *Server) UserInfo(ctx context.Context, r *op.Request[oidc.UserInfoRequest]) (_ *op.Response, err error) { + ctx, span := tracing.NewSpan(ctx) + defer func() { + err = oidcError(err) + span.EndWithError(err) + }() + + features := authz.GetFeatures(ctx) + if features.LegacyIntrospection { + return s.LegacyServer.UserInfo(ctx, r) + } + if features.TriggerIntrospectionProjections { + query.TriggerOIDCUserInfoProjections(ctx) + } + + token, err := s.verifyAccessToken(ctx, r.Data.AccessToken) + if err != nil { + return nil, op.NewStatusError(oidc.ErrAccessDenied().WithDescription("access token invalid").WithParent(err), http.StatusUnauthorized) + } + + var ( + projectID string + assertion bool + ) + if token.clientID != "" { + projectID, assertion, err = s.query.GetOIDCUserinfoClientByID(ctx, token.clientID) + if err != nil { + return nil, err + } + } + + userInfo, err := s.userInfo(ctx, token.userID, projectID, assertion, token.scope, nil) + if err != nil { + return nil, err + } + return op.NewResponse(userInfo), nil +} + +func (s *Server) userInfo(ctx context.Context, userID, projectID string, projectRoleAssertion bool, scope, roleAudience []string) (_ *oidc.UserInfo, err error) { + roleAudience, requestedRoles := prepareRoles(ctx, projectID, projectRoleAssertion, scope, roleAudience) qu, err := s.query.GetOIDCUserInfo(ctx, userID, roleAudience) if err != nil { return nil, err @@ -31,15 +73,14 @@ func (s *Server) userInfo(ctx context.Context, userID, projectID string, scope, // prepareRoles scans the requested scopes, appends to roleAudience and returns the requestedRoles. // +// Scopes with [ScopeProjectRolePrefix] are added to requestedRoles. // When [ScopeProjectsRoles] is present and roleAudience was empty, // project IDs with the [domain.ProjectIDScope] prefix are added to the roleAudience. // -// Scopes with [ScopeProjectRolePrefix] are added to requestedRoles. -// -// If the resulting requestedRoles or roleAudience are not not empty, +// If projectRoleAssertion is true and the resulting requestedRoles or roleAudience are not empty, // the current projectID will always be parts or roleAudience. // Else nil, nil is returned. -func prepareRoles(ctx context.Context, projectID string, scope, roleAudience []string) (ra, requestedRoles []string) { +func prepareRoles(ctx context.Context, projectID string, projectRoleAssertion bool, scope, roleAudience []string) (ra, requestedRoles []string) { // if all roles are requested take the audience for those from the scopes if slices.Contains(scope, ScopeProjectsRoles) && len(roleAudience) == 0 { roleAudience = domain.AddAudScopeToAudience(ctx, roleAudience, scope) @@ -50,7 +91,7 @@ func prepareRoles(ctx context.Context, projectID string, scope, roleAudience []s requestedRoles = append(requestedRoles, role) } } - if len(requestedRoles) == 0 && len(roleAudience) == 0 { + if !projectRoleAssertion && len(requestedRoles) == 0 && len(roleAudience) == 0 { return nil, nil } diff --git a/internal/api/oidc/userinfo_integration_test.go b/internal/api/oidc/userinfo_integration_test.go new file mode 100644 index 0000000000..78cd5479ed --- /dev/null +++ b/internal/api/oidc/userinfo_integration_test.go @@ -0,0 +1,270 @@ +//go:build integration + +package oidc_test + +import ( + "fmt" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/zitadel/oidc/v3/pkg/client/rp" + "github.com/zitadel/oidc/v3/pkg/oidc" + "golang.org/x/oauth2" + + oidc_api "github.com/zitadel/zitadel/internal/api/oidc" + "github.com/zitadel/zitadel/internal/integration" + feature "github.com/zitadel/zitadel/pkg/grpc/feature/v2beta" + "github.com/zitadel/zitadel/pkg/grpc/management" + oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2beta" +) + +// TestServer_UserInfo is a top-level test which re-executes the actual +// userinfo integration test against a matrix of different feature flags. +// This ensure that the response of the different implementations remains the same. +func TestServer_UserInfo(t *testing.T) { + iamOwnerCTX := Tester.WithAuthorization(CTX, integration.IAMOwner) + t.Cleanup(func() { + _, err := Tester.Client.FeatureV2.ResetInstanceFeatures(iamOwnerCTX, &feature.ResetInstanceFeaturesRequest{}) + require.NoError(t, err) + }) + tests := []struct { + name string + legacy bool + trigger bool + }{ + { + name: "legacy enabled", + legacy: true, + }, + { + name: "legacy disabled, trigger disabled", + legacy: false, + trigger: false, + }, + { + name: "legacy disabled, trigger enabled", + legacy: false, + trigger: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := Tester.Client.FeatureV2.SetInstanceFeatures(iamOwnerCTX, &feature.SetInstanceFeaturesRequest{ + OidcLegacyIntrospection: &tt.legacy, + OidcTriggerIntrospectionProjections: &tt.trigger, + }) + require.NoError(t, err) + testServer_UserInfo(t) + }) + } +} + +// testServer_UserInfo is the actual userinfo integration test, +// which calls the userinfo endpoint with different client configurations, roles and token scopes. +func testServer_UserInfo(t *testing.T) { + const role = "testUserRole" + clientID, projectID := createClient(t) + _, err := Tester.Client.Mgmt.AddProjectRole(CTX, &management.AddProjectRoleRequest{ + ProjectId: projectID, + RoleKey: role, + DisplayName: "test", + }) + require.NoError(t, err) + _, err = Tester.Client.Mgmt.AddUserGrant(CTX, &management.AddUserGrantRequest{ + UserId: User.GetUserId(), + ProjectId: projectID, + RoleKeys: []string{role}, + }) + require.NoError(t, err) + + tests := []struct { + name string + prepare func(t *testing.T, clientID string, scope []string) *oidc.Tokens[*oidc.IDTokenClaims] + scope []string + assertions []func(*testing.T, *oidc.UserInfo) + wantErr bool + }{ + { + name: "invalid token", + prepare: func(*testing.T, string, []string) *oidc.Tokens[*oidc.IDTokenClaims] { + return &oidc.Tokens[*oidc.IDTokenClaims]{ + Token: &oauth2.Token{ + AccessToken: "DEAFBEEFDEADBEEF", + TokenType: oidc.BearerToken, + }, + IDTokenClaims: &oidc.IDTokenClaims{ + TokenClaims: oidc.TokenClaims{ + Subject: User.GetUserId(), + }, + }, + } + }, + scope: []string{oidc.ScopeProfile, oidc.ScopeOpenID, oidc.ScopeEmail, oidc.ScopeOfflineAccess}, + assertions: []func(*testing.T, *oidc.UserInfo){ + func(t *testing.T, ui *oidc.UserInfo) { + assert.Nil(t, ui) + }, + }, + wantErr: true, + }, + { + name: "standard scopes", + prepare: getTokens, + scope: []string{oidc.ScopeProfile, oidc.ScopeOpenID, oidc.ScopeEmail, oidc.ScopeOfflineAccess}, + assertions: []func(*testing.T, *oidc.UserInfo){ + assertUserinfo, + func(t *testing.T, ui *oidc.UserInfo) { + assertNoReservedScopes(t, ui.Claims) + }, + }, + }, + { + name: "project role assertion", + prepare: func(t *testing.T, clientID string, scope []string) *oidc.Tokens[*oidc.IDTokenClaims] { + _, err := Tester.Client.Mgmt.UpdateProject(CTX, &management.UpdateProjectRequest{ + Id: projectID, + Name: fmt.Sprintf("project-%d", time.Now().UnixNano()), + ProjectRoleAssertion: true, + }) + require.NoError(t, err) + t.Cleanup(func() { + _, err := Tester.Client.Mgmt.UpdateProject(CTX, &management.UpdateProjectRequest{ + Id: projectID, + Name: fmt.Sprintf("project-%d", time.Now().UnixNano()), + ProjectRoleAssertion: false, + }) + require.NoError(t, err) + }) + resp, err := Tester.Client.Mgmt.GetProjectByID(CTX, &management.GetProjectByIDRequest{Id: projectID}) + require.NoError(t, err) + require.True(t, resp.GetProject().GetProjectRoleAssertion(), "project role assertion") + + return getTokens(t, clientID, scope) + }, + scope: []string{oidc.ScopeProfile, oidc.ScopeOpenID, oidc.ScopeEmail, oidc.ScopeOfflineAccess}, + assertions: []func(*testing.T, *oidc.UserInfo){ + assertUserinfo, + func(t *testing.T, ui *oidc.UserInfo) { + assertProjectRoleClaims(t, projectID, ui.Claims, role) + }, + }, + }, + { + name: "projects roles scope", + prepare: getTokens, + scope: []string{oidc.ScopeProfile, oidc.ScopeOpenID, oidc.ScopeEmail, oidc.ScopeOfflineAccess, oidc_api.ScopeProjectRolePrefix + role}, + assertions: []func(*testing.T, *oidc.UserInfo){ + assertUserinfo, + func(t *testing.T, ui *oidc.UserInfo) { + assertProjectRoleClaims(t, projectID, ui.Claims, role) + }, + }, + }, + { + name: "PAT", + prepare: func(t *testing.T, clientID string, scope []string) *oidc.Tokens[*oidc.IDTokenClaims] { + user := Tester.Users.Get(integration.FirstInstanceUsersKey, integration.OrgOwner) + return &oidc.Tokens[*oidc.IDTokenClaims]{ + Token: &oauth2.Token{ + AccessToken: user.Token, + TokenType: oidc.BearerToken, + }, + IDTokenClaims: &oidc.IDTokenClaims{ + TokenClaims: oidc.TokenClaims{ + Subject: user.ID, + }, + }, + } + }, + assertions: []func(*testing.T, *oidc.UserInfo){ + func(t *testing.T, ui *oidc.UserInfo) { + user := Tester.Users.Get(integration.FirstInstanceUsersKey, integration.OrgOwner) + assert.Equal(t, user.ID, ui.Subject) + assert.Equal(t, user.PreferredLoginName, ui.PreferredUsername) + assert.Equal(t, user.Machine.Name, ui.Name) + assert.Equal(t, user.ResourceOwner, ui.Claims[oidc_api.ClaimResourceOwnerID]) + assert.NotEmpty(t, ui.Claims[oidc_api.ClaimResourceOwnerName]) + assert.NotEmpty(t, ui.Claims[oidc_api.ClaimResourceOwnerPrimaryDomain]) + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tokens := tt.prepare(t, clientID, tt.scope) + provider, err := Tester.CreateRelyingParty(CTX, clientID, redirectURI) + require.NoError(t, err) + userinfo, err := rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.Subject, provider) + if tt.wantErr { + assert.Error(t, err) + return + } + require.NoError(t, err) + for _, assertion := range tt.assertions { + assertion(t, userinfo) + } + }) + } +} + +func getTokens(t *testing.T, clientID string, scope []string) *oidc.Tokens[*oidc.IDTokenClaims] { + authRequestID := createAuthRequest(t, clientID, redirectURI, scope...) + sessionID, sessionToken, startTime, changeTime := Tester.CreateVerifiedWebAuthNSession(t, CTXLOGIN, User.GetUserId()) + linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{ + AuthRequestId: authRequestID, + CallbackKind: &oidc_pb.CreateCallbackRequest_Session{ + Session: &oidc_pb.Session{ + SessionId: sessionID, + SessionToken: sessionToken, + }, + }, + }) + require.NoError(t, err) + + // code exchange + code := assertCodeResponse(t, linkResp.GetCallbackUrl()) + tokens, err := exchangeTokens(t, clientID, code, redirectURI) + require.NoError(t, err) + assertTokens(t, tokens, true) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) + + return tokens +} + +func assertUserinfo(t *testing.T, userinfo *oidc.UserInfo) { + t.Helper() + assert.Equal(t, User.GetUserId(), userinfo.Subject) + assert.Equal(t, "Mickey", userinfo.GivenName) + assert.Equal(t, "Mouse", userinfo.FamilyName) + assert.Equal(t, "Mickey Mouse", userinfo.Name) + assert.NotEmpty(t, userinfo.PreferredUsername) + assert.Equal(t, userinfo.PreferredUsername, userinfo.Email) + assert.False(t, bool(userinfo.EmailVerified)) + assertOIDCTime(t, userinfo.UpdatedAt, User.GetDetails().GetChangeDate().AsTime()) +} + +func assertNoReservedScopes(t *testing.T, claims map[string]any) { + t.Helper() + t.Log(claims) + for claim := range claims { + assert.Falsef(t, strings.HasPrefix(claim, oidc_api.ClaimPrefix), "claim %s has prefix %s", claim, oidc_api.ClaimPrefix) + } +} + +func assertProjectRoleClaims(t *testing.T, projectID string, claims map[string]any, roles ...string) { + t.Helper() + projectIDRoleClaim := fmt.Sprintf(oidc_api.ClaimProjectRolesFormat, projectID) + for _, claim := range []string{oidc_api.ClaimProjectRoles, projectIDRoleClaim} { + roleMap, ok := claims[claim].(map[string]any) + require.Truef(t, ok, "claim %s not found or wrong type %T", claim, claims[claim]) + for _, roleKey := range roles { + role, ok := roleMap[roleKey].(map[string]any) + require.Truef(t, ok, "role %s not found or wrong type %T", roleKey, roleMap[roleKey]) + assert.Equal(t, role[Tester.Organisation.ID], Tester.Organisation.Domain, "org domain in role") + } + } +} diff --git a/internal/api/oidc/userinfo_test.go b/internal/api/oidc/userinfo_test.go index d241f0d86c..21e06e21c4 100644 --- a/internal/api/oidc/userinfo_test.go +++ b/internal/api/oidc/userinfo_test.go @@ -17,9 +17,10 @@ import ( func Test_prepareRoles(t *testing.T) { type args struct { - projectID string - scope []string - roleAudience []string + projectID string + projectRoleAssertion bool + scope []string + roleAudience []string } tests := []struct { name string @@ -30,19 +31,32 @@ func Test_prepareRoles(t *testing.T) { { name: "empty scope and roleAudience", args: args{ - projectID: "projID", - scope: nil, - roleAudience: nil, + projectID: "projID", + projectRoleAssertion: false, + scope: nil, + roleAudience: nil, }, wantRa: nil, wantRequestedRoles: nil, }, + { + name: "project role assertion", + args: args{ + projectID: "projID", + projectRoleAssertion: true, + scope: nil, + roleAudience: nil, + }, + wantRa: []string{"projID"}, + wantRequestedRoles: []string{}, + }, { name: "some scope and roleAudience", args: args{ - projectID: "projID", - scope: []string{"openid", "profile"}, - roleAudience: []string{"project2"}, + projectID: "projID", + projectRoleAssertion: false, + scope: []string{"openid", "profile"}, + roleAudience: []string{"project2"}, }, wantRa: []string{"project2", "projID"}, wantRequestedRoles: []string{}, @@ -50,9 +64,10 @@ func Test_prepareRoles(t *testing.T) { { name: "scope projects roles", args: args{ - projectID: "projID", - scope: []string{ScopeProjectsRoles, domain.ProjectIDScope + "project2" + domain.AudSuffix}, - roleAudience: nil, + projectID: "projID", + projectRoleAssertion: false, + scope: []string{ScopeProjectsRoles, domain.ProjectIDScope + "project2" + domain.AudSuffix}, + roleAudience: nil, }, wantRa: []string{"project2", "projID"}, wantRequestedRoles: []string{}, @@ -60,9 +75,10 @@ func Test_prepareRoles(t *testing.T) { { name: "scope project role prefix", args: args{ - projectID: "projID", - scope: []string{"openid", "profile", ScopeProjectRolePrefix + "foo", ScopeProjectRolePrefix + "bar"}, - roleAudience: nil, + projectID: "projID", + projectRoleAssertion: false, + scope: []string{"openid", "profile", ScopeProjectRolePrefix + "foo", ScopeProjectRolePrefix + "bar"}, + roleAudience: nil, }, wantRa: []string{"projID"}, wantRequestedRoles: []string{"foo", "bar"}, @@ -70,7 +86,7 @@ func Test_prepareRoles(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotRa, gotRequestedRoles := prepareRoles(context.Background(), tt.args.projectID, tt.args.scope, tt.args.roleAudience) + gotRa, gotRequestedRoles := prepareRoles(context.Background(), tt.args.projectID, tt.args.projectRoleAssertion, tt.args.scope, tt.args.roleAudience) assert.Equal(t, tt.wantRa, gotRa, "roleAudience") assert.Equal(t, tt.wantRequestedRoles, gotRequestedRoles, "requestedRoles") }) diff --git a/internal/query/auth_request.go b/internal/query/auth_request.go index 9de2cf9db0..c0554778ab 100644 --- a/internal/query/auth_request.go +++ b/internal/query/auth_request.go @@ -41,7 +41,7 @@ func (a *AuthRequest) checkLoginClient(ctx context.Context) error { return nil } -//go:embed embed/auth_request_by_id.sql +//go:embed auth_request_by_id.sql var authRequestByIDQuery string func (q *Queries) authRequestByIDQuery(ctx context.Context) string { diff --git a/internal/query/embed/auth_request_by_id.sql b/internal/query/auth_request_by_id.sql similarity index 100% rename from internal/query/embed/auth_request_by_id.sql rename to internal/query/auth_request_by_id.sql diff --git a/internal/query/introspection.go b/internal/query/introspection.go index 3f516062fe..a7fdaab718 100644 --- a/internal/query/introspection.go +++ b/internal/query/introspection.go @@ -25,6 +25,8 @@ var introspectionTriggerHandlers = sync.OnceValue(func() []*handler.Handler { ) }) +// TriggerIntrospectionProjections triggers all projections +// relevant to introspection queries concurrently. func TriggerIntrospectionProjections(ctx context.Context) { triggerBatch(ctx, introspectionTriggerHandlers()...) } @@ -37,16 +39,17 @@ const ( ) type IntrospectionClient struct { - AppID string - ClientID string - HashedSecret string - AppType AppType - ProjectID string - ResourceOwner string - PublicKeys database.Map[[]byte] + AppID string + ClientID string + HashedSecret string + AppType AppType + ProjectID string + ResourceOwner string + ProjectRoleAssertion bool + PublicKeys database.Map[[]byte] } -//go:embed embed/introspection_client_by_id.sql +//go:embed introspection_client_by_id.sql var introspectionClientByIDQuery string func (q *Queries) GetIntrospectionClientByID(ctx context.Context, clientID string, getKeys bool) (_ *IntrospectionClient, err error) { @@ -66,6 +69,7 @@ func (q *Queries) GetIntrospectionClientByID(ctx context.Context, clientID strin &client.AppType, &client.ProjectID, &client.ResourceOwner, + &client.ProjectRoleAssertion, &client.PublicKeys, ) }, diff --git a/internal/query/embed/introspection_client_by_id.sql b/internal/query/introspection_client_by_id.sql similarity index 61% rename from internal/query/embed/introspection_client_by_id.sql rename to internal/query/introspection_client_by_id.sql index 9bceff9118..5129d99a70 100644 --- a/internal/query/embed/introspection_client_by_id.sql +++ b/internal/query/introspection_client_by_id.sql @@ -1,10 +1,10 @@ with config as ( - select app_id, client_id, client_secret, 'api' as app_type + select instance_id, app_id, client_id, client_secret, 'api' as app_type from projections.apps7_api_configs where instance_id = $1 and client_id = $2 union - select app_id, client_id, client_secret, 'oidc' as app_type + select instance_id, app_id, client_id, client_secret, 'oidc' as app_type from projections.apps7_oidc_configs where instance_id = $1 and client_id = $2 @@ -18,7 +18,8 @@ keys as ( and expiration > current_timestamp group by identifier ) -select config.app_id, config.client_id, config.client_secret, config.app_type, apps.project_id, apps.resource_owner, keys.public_keys +select config.app_id, config.client_id, config.client_secret, config.app_type, apps.project_id, apps.resource_owner, p.project_role_assertion, keys.public_keys from config -join projections.apps7 apps on apps.id = config.app_id +join projections.apps7 apps on apps.id = config.app_id and apps.instance_id = config.instance_id +join projections.projects4 p on p.id = apps.project_id and p.instance_id = $1 left join keys on keys.client_id = config.client_id; diff --git a/internal/query/introspection_test.go b/internal/query/introspection_test.go index 04998c65cf..6535bd1639 100644 --- a/internal/query/introspection_test.go +++ b/internal/query/introspection_test.go @@ -50,17 +50,18 @@ func TestQueries_GetIntrospectionClientByID(t *testing.T) { getKeys: false, }, mock: mockQuery(expQuery, - []string{"app_id", "client_id", "client_secret", "app_type", "project_id", "resource_owner", "public_keys"}, - []driver.Value{"appID", "clientID", "secret", "oidc", "projectID", "orgID", nil}, + []string{"app_id", "client_id", "client_secret", "app_type", "project_id", "resource_owner", "project_role_assertion", "public_keys"}, + []driver.Value{"appID", "clientID", "secret", "oidc", "projectID", "orgID", true, nil}, "instanceID", "clientID", false), want: &IntrospectionClient{ - AppID: "appID", - ClientID: "clientID", - HashedSecret: "secret", - AppType: AppTypeOIDC, - ProjectID: "projectID", - ResourceOwner: "orgID", - PublicKeys: nil, + AppID: "appID", + ClientID: "clientID", + HashedSecret: "secret", + AppType: AppTypeOIDC, + ProjectID: "projectID", + ResourceOwner: "orgID", + ProjectRoleAssertion: true, + PublicKeys: nil, }, }, { @@ -70,17 +71,18 @@ func TestQueries_GetIntrospectionClientByID(t *testing.T) { getKeys: true, }, mock: mockQuery(expQuery, - []string{"app_id", "client_id", "client_secret", "app_type", "project_id", "resource_owner", "public_keys"}, - []driver.Value{"appID", "clientID", "", "oidc", "projectID", "orgID", encPubkeys}, + []string{"app_id", "client_id", "client_secret", "app_type", "project_id", "resource_owner", "project_role_assertion", "public_keys"}, + []driver.Value{"appID", "clientID", "", "oidc", "projectID", "orgID", true, encPubkeys}, "instanceID", "clientID", true), want: &IntrospectionClient{ - AppID: "appID", - ClientID: "clientID", - HashedSecret: "", - AppType: AppTypeOIDC, - ProjectID: "projectID", - ResourceOwner: "orgID", - PublicKeys: pubkeys, + AppID: "appID", + ClientID: "clientID", + HashedSecret: "", + AppType: AppTypeOIDC, + ProjectID: "projectID", + ResourceOwner: "orgID", + ProjectRoleAssertion: true, + PublicKeys: pubkeys, }, }, } diff --git a/internal/query/oidc_client.go b/internal/query/oidc_client.go index 67a9c7d5eb..6669b398b5 100644 --- a/internal/query/oidc_client.go +++ b/internal/query/oidc_client.go @@ -35,11 +35,12 @@ type OIDCClient struct { AdditionalOrigins []string `json:"additional_origins,omitempty"` PublicKeys map[string][]byte `json:"public_keys,omitempty"` ProjectID string `json:"project_id,omitempty"` + ProjectRoleAssertion bool `json:"project_role_assertion,omitempty"` ProjectRoleKeys []string `json:"project_role_keys,omitempty"` Settings *OIDCSettings `json:"settings,omitempty"` } -//go:embed embed/oidc_client_by_id.sql +//go:embed oidc_client_by_id.sql var oidcClientQuery string func (q *Queries) GetOIDCClientByID(ctx context.Context, clientID string, getKeys bool) (client *OIDCClient, err error) { diff --git a/internal/query/embed/oidc_client_by_id.sql b/internal/query/oidc_client_by_id.sql similarity index 84% rename from internal/query/embed/oidc_client_by_id.sql rename to internal/query/oidc_client_by_id.sql index 07f45cf68f..3a0a0a0c95 100644 --- a/internal/query/embed/oidc_client_by_id.sql +++ b/internal/query/oidc_client_by_id.sql @@ -1,14 +1,12 @@ ---deallocate q; ---prepare q(text, text, boolean) as - with client as ( select c.instance_id, - c.app_id, c.client_id, c.client_secret, c.redirect_uris, c.response_types, c.grant_types, + c.app_id, a.state, c.client_id, c.client_secret, c.redirect_uris, c.response_types, c.grant_types, c.application_type, c.auth_method_type, c.post_logout_redirect_uris, c.is_dev_mode, c.access_token_type, c.access_token_role_assertion, c.id_token_role_assertion, - c.id_token_userinfo_assertion, c.clock_skew, c.additional_origins, a.project_id, a.state + c.id_token_userinfo_assertion, c.clock_skew, c.additional_origins, a.project_id, p.project_role_assertion from projections.apps7_oidc_configs c join projections.apps7 a on a.id = c.app_id and a.instance_id = c.instance_id + join projections.projects4 p on p.id = a.project_id and p.instance_id = a.instance_id where c.instance_id = $1 and c.client_id = $2 ), @@ -45,7 +43,5 @@ select row_to_json(r) as client from ( from client c left join roles r on r.project_id = c.project_id left join keys k on k.client_id = c.client_id - left join settings s on s.instance_id = s.instance_id + left join settings s on s.instance_id = c.instance_id ) r; - ---execute q('230690539048009730', '236647088211951618@tests', true); \ No newline at end of file diff --git a/internal/query/oidc_client_test.go b/internal/query/oidc_client_test.go index bfbbe74098..93bd428015 100644 --- a/internal/query/oidc_client_test.go +++ b/internal/query/oidc_client_test.go @@ -80,6 +80,7 @@ low2kyJov38V4Uk2I8kuXpLcnrpw5Tio2ooiUE27b0vHZqBKOei9Uo88qCrn3EKx ClockSkew: 1000000000, AdditionalOrigins: []string{"https://example.com"}, ProjectID: "236645808328409090", + ProjectRoleAssertion: true, PublicKeys: map[string][]byte{"236647201860747266": []byte(pubkey)}, ProjectRoleKeys: []string{"role1", "role2"}, Settings: &OIDCSettings{ @@ -112,6 +113,7 @@ low2kyJov38V4Uk2I8kuXpLcnrpw5Tio2ooiUE27b0vHZqBKOei9Uo88qCrn3EKx AdditionalOrigins: nil, PublicKeys: nil, ProjectID: "236645808328409090", + ProjectRoleAssertion: true, ProjectRoleKeys: []string{"role1", "role2"}, Settings: &OIDCSettings{ AccessTokenLifetime: 43200000000000, @@ -143,6 +145,7 @@ low2kyJov38V4Uk2I8kuXpLcnrpw5Tio2ooiUE27b0vHZqBKOei9Uo88qCrn3EKx AdditionalOrigins: nil, PublicKeys: nil, ProjectID: "236645808328409090", + ProjectRoleAssertion: false, ProjectRoleKeys: []string{"role1", "role2"}, Settings: &OIDCSettings{ AccessTokenLifetime: 43200000000000, @@ -179,6 +182,7 @@ low2kyJov38V4Uk2I8kuXpLcnrpw5Tio2ooiUE27b0vHZqBKOei9Uo88qCrn3EKx AdditionalOrigins: nil, PublicKeys: nil, ProjectID: "239520764276178946", + ProjectRoleAssertion: false, ProjectRoleKeys: nil, Settings: nil, }, diff --git a/internal/query/testdata/oidc_client_jwt.json b/internal/query/testdata/oidc_client_jwt.json index df871815dd..1bca6044d4 100644 --- a/internal/query/testdata/oidc_client_jwt.json +++ b/internal/query/testdata/oidc_client_jwt.json @@ -1,6 +1,7 @@ { "instance_id": "230690539048009730", "app_id": "236647088211886082", + "state": 1, "client_id": "236647088211951618@tests", "client_secret": null, "redirect_uris": ["http://localhost:9999/auth/callback"], @@ -17,7 +18,7 @@ "clock_skew": 1000000000, "additional_origins": ["https://example.com"], "project_id": "236645808328409090", - "state": 1, + "project_role_assertion": true, "project_role_keys": ["role1", "role2"], "public_keys": { "236647201860747266": "LS0tLS1CRUdJTiBSU0EgUFVCTElDIEtFWS0tLS0tCk1JSUJJakFOQmdrcWhraUc5dzBCQVFFRkFB\nT0NBUThBTUlJQkNnS0NBUUVBMnVmQUwxYjcyYkl5MWFyK1dzNmIKR29oSkpRRkI3ZGZSYXBEcWVx\nTThVa3A2Q1ZkUHpxL3BPejF2aUFxNTB5eldaSnJ5Risyd3NoRkFLR0Y5QTIvQgoyWWY5YkpYUFov\nS2JrRnJZVDNOVHZZRGt2bGFTVGw5bU1uenJVMjlzNDhGMVBUV0tmQitDM2FNc09FRzFCdWZWCnM2\nM3FGNG5yRVBqU2JobGpJY285RlpxNFhwcEl6aE1RMGZEZEEvK1h5Z0NKcXZ1YUwwTGliTTFLcmxV\nZG51NzEKWWVraFNKakVQbnZPaXNYSWs0SVh5d29HSU93dGp4a0R2Tkl0UXZhTVZsZHI0L2tiNnV2\nYmdkV3dxNUV3QlpYcQpsb3cya3lKb3YzOFY0VWsySThrdVhwTGNucnB3NVRpbzJvb2lVRTI3YjB2\nSFpxQktPZWk5VW84OHFDcm4zRUt4CjZRSURBUUFCCi0tLS0tRU5EIFJTQSBQVUJMSUMgS0VZLS0t\nLS0K" diff --git a/internal/query/testdata/oidc_client_no_settings.json b/internal/query/testdata/oidc_client_no_settings.json index 83d810d669..59aff6ea42 100644 --- a/internal/query/testdata/oidc_client_no_settings.json +++ b/internal/query/testdata/oidc_client_no_settings.json @@ -1,6 +1,7 @@ { "instance_id": "239520764275982338", "app_id": "239520764276441090", + "state": 1, "client_id": "239520764779364354@zitadel", "client_secret": null, "redirect_uris": [ @@ -23,7 +24,7 @@ "clock_skew": 0, "additional_origins": null, "project_id": "239520764276178946", - "state": 1, + "project_role_assertion": false, "project_role_keys": null, "public_keys": null, "settings": null diff --git a/internal/query/testdata/oidc_client_public.json b/internal/query/testdata/oidc_client_public.json index 47cf750c8b..020c60311b 100644 --- a/internal/query/testdata/oidc_client_public.json +++ b/internal/query/testdata/oidc_client_public.json @@ -1,6 +1,7 @@ { "instance_id": "230690539048009730", "app_id": "236646457053020162", + "state": 1, "client_id": "236646457053085698@tests", "client_secret": null, "redirect_uris": ["http://localhost:9999/auth/callback"], @@ -17,7 +18,7 @@ "clock_skew": 0, "additional_origins": null, "project_id": "236645808328409090", - "state": 1, + "project_role_assertion": true, "project_role_keys": ["role1", "role2"], "public_keys": null, "settings": { diff --git a/internal/query/testdata/oidc_client_secret.json b/internal/query/testdata/oidc_client_secret.json index d12544f23d..0fb1d6f830 100644 --- a/internal/query/testdata/oidc_client_secret.json +++ b/internal/query/testdata/oidc_client_secret.json @@ -1,6 +1,7 @@ { "instance_id": "230690539048009730", "app_id": "236646858984783874", + "state": 1, "client_id": "236646858984849410@tests", "client_secret": "$2a$14$OzZ0XEZZEtD13py/EPba2evsS6WcKZ5orVMj9pWHEGEHmLu2h3PFq", "redirect_uris": ["http://localhost:9999/auth/callback"], @@ -17,7 +18,7 @@ "clock_skew": 0, "additional_origins": null, "project_id": "236645808328409090", - "state": 1, + "project_role_assertion": false, "project_role_keys": ["role1", "role2"], "public_keys": null, "settings": { diff --git a/internal/query/userinfo.go b/internal/query/userinfo.go index 2e2c27f9bc..3231817511 100644 --- a/internal/query/userinfo.go +++ b/internal/query/userinfo.go @@ -29,11 +29,13 @@ var oidcUserInfoTriggerHandlers = sync.OnceValue(func() []*handler.Handler { } }) +// TriggerOIDCUserInfoProjections triggers all projections +// relevant to userinfo queries concurrently. func TriggerOIDCUserInfoProjections(ctx context.Context) { triggerBatch(ctx, oidcUserInfoTriggerHandlers()...) } -//go:embed embed/userinfo_by_id.sql +//go:embed userinfo_by_id.sql var oidcUserInfoQuery string func (q *Queries) GetOIDCUserInfo(ctx context.Context, userID string, roleAudience []string) (_ *OIDCUserInfo, err error) { @@ -68,3 +70,25 @@ type UserInfoOrg struct { Name string `json:"name,omitempty"` PrimaryDomain string `json:"primary_domain,omitempty"` } + +//go:embed userinfo_client_by_id.sql +var oidcUserinfoClientQuery string + +func (q *Queries) GetOIDCUserinfoClientByID(ctx context.Context, clientID string) (projectID string, projectRoleAssertion bool, err error) { + ctx, span := tracing.NewSpan(ctx) + defer func() { span.EndWithError(err) }() + + scan := func(row *sql.Row) error { + err := row.Scan(&projectID, &projectRoleAssertion) + return err + } + + err = q.client.QueryRowContext(ctx, scan, oidcUserinfoClientQuery, authz.GetInstance(ctx).InstanceID(), clientID) + if errors.Is(err, sql.ErrNoRows) { + return "", false, zerrors.ThrowNotFound(err, "QUERY-beeW8", "Errors.App.NotFound") + } + if err != nil { + return "", false, zerrors.ThrowInternal(err, "QUERY-Ais4r", "Errors.Internal") + } + return projectID, projectRoleAssertion, nil +} diff --git a/internal/query/embed/userinfo_by_id.sql b/internal/query/userinfo_by_id.sql similarity index 100% rename from internal/query/embed/userinfo_by_id.sql rename to internal/query/userinfo_by_id.sql diff --git a/internal/query/userinfo_client_by_id.sql b/internal/query/userinfo_client_by_id.sql new file mode 100644 index 0000000000..615709c6df --- /dev/null +++ b/internal/query/userinfo_client_by_id.sql @@ -0,0 +1,6 @@ +select a.project_id, p.project_role_assertion +from projections.apps7_oidc_configs c +join projections.apps7 a on a.id = c.app_id and a.instance_id = c.instance_id +join projections.projects4 p on p.id = a.project_id and p.instance_id = a.instance_id +where c.instance_id = $1 + and c.client_id = $2; diff --git a/internal/query/userinfo_test.go b/internal/query/userinfo_test.go index 04a9edccb7..29d94d0baf 100644 --- a/internal/query/userinfo_test.go +++ b/internal/query/userinfo_test.go @@ -338,3 +338,50 @@ func TestQueries_GetOIDCUserInfo(t *testing.T) { }) } } + +func TestQueries_GetOIDCUserinfoClientByID(t *testing.T) { + expQuery := regexp.QuoteMeta(oidcUserinfoClientQuery) + cols := []string{"project_id", "project_role_assertion"} + + tests := []struct { + name string + mock sqlExpectation + wantProjectID string + wantProjectRoleAssertion bool + wantErr error + }{ + { + name: "no rows", + mock: mockQueryErr(expQuery, sql.ErrNoRows, "instanceID", "clientID"), + wantErr: zerrors.ThrowNotFound(sql.ErrNoRows, "QUERY-beeW8", "Errors.App.NotFound"), + }, + { + name: "internal error", + mock: mockQueryErr(expQuery, sql.ErrConnDone, "instanceID", "clientID"), + wantErr: zerrors.ThrowInternal(sql.ErrConnDone, "QUERY-Ais4r", "Errors.Internal"), + }, + { + name: "found", + mock: mockQuery(expQuery, cols, []driver.Value{"projectID", true}, "instanceID", "clientID"), + wantProjectID: "projectID", + wantProjectRoleAssertion: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + execMock(t, tt.mock, func(db *sql.DB) { + q := &Queries{ + client: &database.DB{ + DB: db, + Database: &prepareDB{}, + }, + } + ctx := authz.NewMockContext("instanceID", "orgID", "loginClient") + gotProjectID, gotProjectRoleAssertion, err := q.GetOIDCUserinfoClientByID(ctx, "clientID") + require.ErrorIs(t, err, tt.wantErr) + assert.Equal(t, tt.wantProjectID, gotProjectID) + assert.Equal(t, tt.wantProjectRoleAssertion, gotProjectRoleAssertion) + }) + }) + } +}