From 25a2cd4aa4580f18f903d1d64eacb772982ddd3c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20M=C3=B6hlmann?= Date: Thu, 4 Apr 2024 16:11:01 +0300 Subject: [PATCH] feat(oidc): optimize the userinfo endpoint --- internal/api/oidc/client_integration_test.go | 11 --- internal/api/oidc/introspect.go | 1 - internal/api/oidc/oidc_integration_test.go | 6 +- internal/api/oidc/server.go | 7 -- internal/api/oidc/userinfo.go | 30 ++++++ .../api/oidc/userinfo_integration_test.go | 96 +++++++++++++++++++ internal/query/introspection.go | 2 + internal/query/userinfo.go | 2 + 8 files changed, 133 insertions(+), 22 deletions(-) create mode 100644 internal/api/oidc/userinfo_integration_test.go diff --git a/internal/api/oidc/client_integration_test.go b/internal/api/oidc/client_integration_test.go index 96812120d80..a292d8f07a2 100644 --- a/internal/api/oidc/client_integration_test.go +++ b/internal/api/oidc/client_integration_test.go @@ -172,17 +172,6 @@ func TestServer_Introspect(t *testing.T) { } } -func assertUserinfo(t *testing.T, userinfo *oidc.UserInfo) { - assert.Equal(t, User.GetUserId(), userinfo.Subject) - assert.Equal(t, "Mickey", userinfo.GivenName) - assert.Equal(t, "Mouse", userinfo.FamilyName) - assert.Equal(t, "Mickey Mouse", userinfo.Name) - assert.NotEmpty(t, userinfo.PreferredUsername) - assert.Equal(t, userinfo.PreferredUsername, userinfo.Email) - assert.False(t, bool(userinfo.EmailVerified)) - assertOIDCTime(t, userinfo.UpdatedAt, User.GetDetails().GetChangeDate().AsTime()) -} - func assertIntrospection( t *testing.T, introspection *oidc.IntrospectionResponse, diff --git a/internal/api/oidc/introspect.go b/internal/api/oidc/introspect.go index 04934e0dc80..b5263dceb12 100644 --- a/internal/api/oidc/introspect.go +++ b/internal/api/oidc/introspect.go @@ -29,7 +29,6 @@ func (s *Server) Introspect(ctx context.Context, r *op.Request[op.IntrospectionR return s.LegacyServer.Introspect(ctx, r) } if features.TriggerIntrospectionProjections { - // Execute all triggers in one concurrent sweep. query.TriggerIntrospectionProjections(ctx) } diff --git a/internal/api/oidc/oidc_integration_test.go b/internal/api/oidc/oidc_integration_test.go index 1e6f8ed1188..74e8e9d568d 100644 --- a/internal/api/oidc/oidc_integration_test.go +++ b/internal/api/oidc/oidc_integration_test.go @@ -39,17 +39,17 @@ const ( func TestMain(m *testing.M) { os.Exit(func() int { - ctx, errCtx, cancel := integration.Contexts(10 * time.Minute) + ctx, _, cancel := integration.Contexts(10 * time.Minute) defer cancel() Tester = integration.NewTester(ctx) defer Tester.Done() - CTX, _ = Tester.WithAuthorization(ctx, integration.OrgOwner), errCtx + CTX = Tester.WithAuthorization(ctx, integration.OrgOwner) User = Tester.CreateHumanUser(CTX) Tester.SetUserPassword(CTX, User.GetUserId(), integration.UserPassword, false) Tester.RegisterUserPasskey(CTX, User.GetUserId()) - CTXLOGIN, _ = Tester.WithAuthorization(ctx, integration.Login), errCtx + CTXLOGIN = Tester.WithAuthorization(ctx, integration.Login) return m.Run() }()) } diff --git a/internal/api/oidc/server.go b/internal/api/oidc/server.go index fe10db3219c..f2a487d03f8 100644 --- a/internal/api/oidc/server.go +++ b/internal/api/oidc/server.go @@ -188,13 +188,6 @@ func (s *Server) DeviceToken(ctx context.Context, r *op.ClientRequest[oidc.Devic return s.LegacyServer.DeviceToken(ctx, r) } -func (s *Server) UserInfo(ctx context.Context, r *op.Request[oidc.UserInfoRequest]) (_ *op.Response, err error) { - ctx, span := tracing.NewSpan(ctx) - defer func() { span.EndWithError(err) }() - - return s.LegacyServer.UserInfo(ctx, r) -} - func (s *Server) Revocation(ctx context.Context, r *op.ClientRequest[oidc.RevocationRequest]) (_ *op.Response, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() diff --git a/internal/api/oidc/userinfo.go b/internal/api/oidc/userinfo.go index a960c6ceca8..fdba9b7e4ed 100644 --- a/internal/api/oidc/userinfo.go +++ b/internal/api/oidc/userinfo.go @@ -5,19 +5,49 @@ import ( "encoding/base64" "encoding/json" "fmt" + "net/http" "slices" "strings" "github.com/dop251/goja" "github.com/zitadel/logging" "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" "github.com/zitadel/zitadel/internal/actions" "github.com/zitadel/zitadel/internal/actions/object" + "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/query" + "github.com/zitadel/zitadel/internal/telemetry/tracing" ) +func (s *Server) UserInfo(ctx context.Context, r *op.Request[oidc.UserInfoRequest]) (_ *op.Response, err error) { + ctx, span := tracing.NewSpan(ctx) + defer func() { + err = oidcError(err) + span.EndWithError(err) + }() + + features := authz.GetFeatures(ctx) + if features.LegacyIntrospection { + return s.LegacyServer.UserInfo(ctx, r) + } + if features.TriggerIntrospectionProjections { + query.TriggerOIDCUserInfoProjections(ctx) + } + + token, err := s.verifyAccessToken(ctx, r.Data.AccessToken) + if err != nil { + return nil, op.NewStatusError(oidc.ErrAccessDenied().WithDescription("access token invalid").WithParent(err), http.StatusUnauthorized) + } + userInfo, err := s.userInfo(ctx, token.userID, "", token.scope, token.audience) + if err != nil { + return nil, err + } + return op.NewResponse(userInfo), nil +} + func (s *Server) userInfo(ctx context.Context, userID, projectID string, scope, roleAudience []string) (_ *oidc.UserInfo, err error) { roleAudience, requestedRoles := prepareRoles(ctx, projectID, scope, roleAudience) qu, err := s.query.GetOIDCUserInfo(ctx, userID, roleAudience) diff --git a/internal/api/oidc/userinfo_integration_test.go b/internal/api/oidc/userinfo_integration_test.go new file mode 100644 index 00000000000..81a3dc2dbda --- /dev/null +++ b/internal/api/oidc/userinfo_integration_test.go @@ -0,0 +1,96 @@ +//go:build integration + +package oidc_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/zitadel/oidc/v3/pkg/client/rp" + "github.com/zitadel/oidc/v3/pkg/oidc" + + "github.com/zitadel/zitadel/internal/integration" + feature "github.com/zitadel/zitadel/pkg/grpc/feature/v2beta" + oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2beta" +) + +func TestServer_UserInfo(t *testing.T) { + iamOwnerCTX := Tester.WithAuthorization(CTX, integration.IAMOwner) + t.Cleanup(func() { + _, err := Tester.Client.FeatureV2.ResetInstanceFeatures(iamOwnerCTX, &feature.ResetInstanceFeaturesRequest{}) + require.NoError(t, err) + }) + tests := []struct { + name string + legacy bool + trigger bool + }{ + { + name: "legacy enabled", + legacy: true, + }, + { + name: "legacy and trigger disabled", + legacy: false, + trigger: false, + }, + { + name: "legacy disabled, trigger enabled", + legacy: false, + trigger: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := Tester.Client.FeatureV2.SetInstanceFeatures(iamOwnerCTX, &feature.SetInstanceFeaturesRequest{ + OidcLegacyIntrospection: &tt.legacy, + OidcTriggerIntrospectionProjections: &tt.trigger, + }) + require.NoError(t, err) + testServer_UserInfo(t) + }) + } +} + +func testServer_UserInfo(t *testing.T) { + clientID := createClient(t) + authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeOfflineAccess) + sessionID, sessionToken, startTime, changeTime := Tester.CreateVerifiedWebAuthNSession(t, CTXLOGIN, User.GetUserId()) + linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{ + AuthRequestId: authRequestID, + CallbackKind: &oidc_pb.CreateCallbackRequest_Session{ + Session: &oidc_pb.Session{ + SessionId: sessionID, + SessionToken: sessionToken, + }, + }, + }) + require.NoError(t, err) + + // code exchange + code := assertCodeResponse(t, linkResp.GetCallbackUrl()) + tokens, err := exchangeTokens(t, clientID, code, redirectURI) + require.NoError(t, err) + assertTokens(t, tokens, true) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) + + // test actual userinfo + provider, err := Tester.CreateRelyingParty(CTX, clientID, redirectURI) + require.NoError(t, err) + userinfo, err := rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.Subject, provider) + require.NoError(t, err) + assertUserinfo(t, userinfo) +} + +func assertUserinfo(t *testing.T, userinfo *oidc.UserInfo) { + assert.Equal(t, User.GetUserId(), userinfo.Subject) + assert.Equal(t, "Mickey", userinfo.GivenName) + assert.Equal(t, "Mouse", userinfo.FamilyName) + assert.Equal(t, "Mickey Mouse", userinfo.Name) + assert.NotEmpty(t, userinfo.PreferredUsername) + assert.Equal(t, userinfo.PreferredUsername, userinfo.Email) + assert.False(t, bool(userinfo.EmailVerified)) + assertOIDCTime(t, userinfo.UpdatedAt, User.GetDetails().GetChangeDate().AsTime()) +} diff --git a/internal/query/introspection.go b/internal/query/introspection.go index 0e190da25d7..1924a105ae6 100644 --- a/internal/query/introspection.go +++ b/internal/query/introspection.go @@ -26,6 +26,8 @@ var introspectionTriggerHandlers = sync.OnceValue(func() []*handler.Handler { ) }) +// TriggerIntrospectionProjections triggers all projections +// relevant to introspection queries concurrently. func TriggerIntrospectionProjections(ctx context.Context) { triggerBatch(ctx, introspectionTriggerHandlers()...) } diff --git a/internal/query/userinfo.go b/internal/query/userinfo.go index 2e2c27f9bce..62095011304 100644 --- a/internal/query/userinfo.go +++ b/internal/query/userinfo.go @@ -29,6 +29,8 @@ var oidcUserInfoTriggerHandlers = sync.OnceValue(func() []*handler.Handler { } }) +// TriggerOIDCUserInfoProjections triggers all projections +// relevant to userinfo queries concurrently. func TriggerOIDCUserInfoProjections(ctx context.Context) { triggerBatch(ctx, oidcUserInfoTriggerHandlers()...) }