mirror of
https://github.com/zitadel/zitadel.git
synced 2024-12-04 23:45:07 +00:00
perf(oidc): remove get user by ID from jwt profile grant (#8580)
# Which Problems Are Solved Improve performance by removing a GetUserByID call. The call also executed a Trigger on projections, which significantly impacted concurrent requests. # How the Problems Are Solved Token creation needs information from the user, such as the resource owner and access token type. For client credentials this is solved in a single search. By getting the user by username (`client_id`), the user details and secret were obtained in a single query. After that verification and token creation can proceed. For JWT profile it is a bit more complex. We didn't know anything about the user until after JWT verification. The verification did a query for the AuthN key and after that we did a GetUserByID to get remaining details. This change uses a joined query when the OIDC library calls the `GetKeyByIDAndClientID` method on the token storage. The found user details are set to the verifieer object and returned after verification is completed. It is safe because the `jwtProfileKeyStorage` is a single-use object as a wrapper around `query.Queries`. This way getting the public key and user details are obtained in a single query. # Additional Changes - Correctly set the `client_id` field with machine's username. # Additional Context - Related to: https://github.com/zitadel/zitadel/issues/8352
This commit is contained in:
parent
3aba942162
commit
58a7eb1f26
@ -1043,11 +1043,11 @@ func (s *Server) verifyClientSecret(ctx context.Context, client *query.OIDCClien
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) checkOrgScopes(ctx context.Context, user *query.User, scopes []string) ([]string, error) {
|
||||
func (s *Server) checkOrgScopes(ctx context.Context, resourceOwner string, scopes []string) ([]string, error) {
|
||||
if slices.ContainsFunc(scopes, func(scope string) bool {
|
||||
return strings.HasPrefix(scope, domain.OrgDomainPrimaryScope)
|
||||
}) {
|
||||
org, err := s.query.OrgByID(ctx, false, user.ResourceOwner)
|
||||
org, err := s.query.OrgByID(ctx, false, resourceOwner)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -1060,7 +1060,7 @@ func (s *Server) checkOrgScopes(ctx context.Context, user *query.User, scopes []
|
||||
}
|
||||
return slices.DeleteFunc(scopes, func(scope string) bool {
|
||||
if orgID, ok := strings.CutPrefix(scope, domain.OrgIDScope); ok {
|
||||
return orgID != user.ResourceOwner
|
||||
return orgID != resourceOwner
|
||||
}
|
||||
return false
|
||||
}), nil
|
||||
|
@ -9,7 +9,7 @@ import (
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
@ -56,25 +56,29 @@ func (s *Server) clientCredentialsAuth(ctx context.Context, clientID, clientSecr
|
||||
|
||||
s.command.MachineSecretCheckSucceeded(ctx, user.ID, user.ResourceOwner, updated)
|
||||
return &clientCredentialsClient{
|
||||
id: clientID,
|
||||
user: user,
|
||||
clientID: user.Username,
|
||||
userID: user.ID,
|
||||
resourceOwner: user.ResourceOwner,
|
||||
tokenType: user.Machine.AccessTokenType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type clientCredentialsClient struct {
|
||||
id string
|
||||
user *query.User
|
||||
clientID string
|
||||
userID string
|
||||
resourceOwner string
|
||||
tokenType domain.OIDCTokenType
|
||||
}
|
||||
|
||||
// AccessTokenType returns the AccessTokenType for the token to be created because of the client credentials request
|
||||
// machine users currently only have opaque tokens ([op.AccessTokenTypeBearer])
|
||||
func (c *clientCredentialsClient) AccessTokenType() op.AccessTokenType {
|
||||
return accessTokenTypeToOIDC(c.user.Machine.AccessTokenType)
|
||||
return accessTokenTypeToOIDC(c.tokenType)
|
||||
}
|
||||
|
||||
// GetID returns the client_id (username of the machine user) for the token to be created because of the client credentials request
|
||||
func (c *clientCredentialsClient) GetID() string {
|
||||
return c.id
|
||||
return c.clientID
|
||||
}
|
||||
|
||||
// RedirectURIs returns nil as there are no redirect uris
|
||||
|
@ -26,15 +26,15 @@ func (s *Server) ClientCredentialsExchange(ctx context.Context, r *op.ClientRequ
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
scope, err = s.checkOrgScopes(ctx, client.user, scope)
|
||||
scope, err = s.checkOrgScopes(ctx, client.resourceOwner, scope)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
session, err := s.command.CreateOIDCSession(ctx,
|
||||
client.user.ID,
|
||||
client.user.ResourceOwner,
|
||||
"",
|
||||
client.userID,
|
||||
client.resourceOwner,
|
||||
client.clientID,
|
||||
scope,
|
||||
domain.AddAudScopeToAudience(ctx, nil, r.Data.Scope),
|
||||
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
|
||||
|
@ -21,28 +21,30 @@ func (s *Server) JWTProfile(ctx context.Context, r *op.Request[oidc.JWTProfileGr
|
||||
err = oidcError(err)
|
||||
}()
|
||||
|
||||
user, jwtReq, err := s.verifyJWTProfile(ctx, r.Data)
|
||||
user, err := s.verifyJWTProfile(ctx, r.Data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client := &clientCredentialsClient{
|
||||
id: jwtReq.Subject,
|
||||
user: user,
|
||||
clientID: user.Username,
|
||||
userID: user.UserID,
|
||||
resourceOwner: user.ResourceOwner,
|
||||
tokenType: user.TokenType,
|
||||
}
|
||||
scope, err := op.ValidateAuthReqScopes(client, r.Data.Scope)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
scope, err = s.checkOrgScopes(ctx, client.user, scope)
|
||||
scope, err = s.checkOrgScopes(ctx, client.resourceOwner, scope)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
session, err := s.command.CreateOIDCSession(ctx,
|
||||
user.ID,
|
||||
user.ResourceOwner,
|
||||
"",
|
||||
client.userID,
|
||||
client.resourceOwner,
|
||||
client.clientID,
|
||||
scope,
|
||||
domain.AddAudScopeToAudience(ctx, nil, r.Data.Scope),
|
||||
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePrivateKey},
|
||||
@ -61,37 +63,33 @@ func (s *Server) JWTProfile(ctx context.Context, r *op.Request[oidc.JWTProfileGr
|
||||
return response(s.accessTokenResponseFromSession(ctx, client, session, "", "", false, true, true, false))
|
||||
}
|
||||
|
||||
func (s *Server) verifyJWTProfile(ctx context.Context, req *oidc.JWTProfileGrantRequest) (user *query.User, tokenRequest *oidc.JWTTokenRequest, err error) {
|
||||
func (s *Server) verifyJWTProfile(ctx context.Context, req *oidc.JWTProfileGrantRequest) (_ *query.AuthNKeyUser, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
checkSubject := func(jwt *oidc.JWTTokenRequest) (err error) {
|
||||
user, err = s.query.GetUserByID(ctx, true, jwt.Subject)
|
||||
return err
|
||||
}
|
||||
storage := &jwtProfileKeyStorage{query: s.query}
|
||||
verifier := op.NewJWTProfileVerifier(
|
||||
&jwtProfileKeyStorage{query: s.query},
|
||||
op.IssuerFromContext(ctx),
|
||||
storage, op.IssuerFromContext(ctx),
|
||||
time.Hour, time.Second,
|
||||
op.SubjectCheck(checkSubject),
|
||||
)
|
||||
tokenRequest, err = op.VerifyJWTAssertion(ctx, req.Assertion, verifier)
|
||||
_, err = op.VerifyJWTAssertion(ctx, req.Assertion, verifier)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, err
|
||||
}
|
||||
return user, tokenRequest, nil
|
||||
return storage.user, nil
|
||||
}
|
||||
|
||||
type jwtProfileKeyStorage struct {
|
||||
query *query.Queries
|
||||
user *query.AuthNKeyUser // only populated after GetKeyByIDAndClientID is called
|
||||
}
|
||||
|
||||
func (s *jwtProfileKeyStorage) GetKeyByIDAndClientID(ctx context.Context, keyID, userID string) (*jose.JSONWebKey, error) {
|
||||
publicKeyData, err := s.query.GetAuthNKeyPublicKeyByIDAndIdentifier(ctx, keyID, userID)
|
||||
func (s *jwtProfileKeyStorage) GetKeyByIDAndClientID(ctx context.Context, keyID, userID string) (_ *jose.JSONWebKey, err error) {
|
||||
s.user, err = s.query.GetAuthNKeyUser(ctx, keyID, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
publicKey, err := crypto.BytesToPublicKey(publicKeyData)
|
||||
publicKey, err := crypto.BytesToPublicKey(s.user.PublicKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -3,6 +3,7 @@ package query
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
_ "embed"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
@ -249,6 +250,44 @@ func NewAuthNKeyObjectIDQuery(id string) (SearchQuery, error) {
|
||||
return NewTextQuery(AuthNKeyColumnObjectID, id, TextEquals)
|
||||
}
|
||||
|
||||
//go:embed authn_key_user.sql
|
||||
var authNKeyUserQuery string
|
||||
|
||||
type AuthNKeyUser struct {
|
||||
UserID string
|
||||
ResourceOwner string
|
||||
Username string
|
||||
TokenType domain.OIDCTokenType
|
||||
PublicKey []byte
|
||||
}
|
||||
|
||||
func (q *Queries) GetAuthNKeyUser(ctx context.Context, keyID, userID string) (_ *AuthNKeyUser, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
dst := new(AuthNKeyUser)
|
||||
err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
|
||||
return row.Scan(
|
||||
&dst.UserID,
|
||||
&dst.ResourceOwner,
|
||||
&dst.Username,
|
||||
&dst.TokenType,
|
||||
&dst.PublicKey,
|
||||
)
|
||||
},
|
||||
authNKeyUserQuery,
|
||||
authz.GetInstance(ctx).InstanceID(),
|
||||
keyID, userID,
|
||||
)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, zerrors.ThrowNotFound(err, "QUERY-Tha6f", "Errors.AuthNKey.NotFound")
|
||||
}
|
||||
return nil, zerrors.ThrowInternal(err, "QUERY-aen2A", "Errors.Internal")
|
||||
}
|
||||
return dst, nil
|
||||
}
|
||||
|
||||
func prepareAuthNKeysQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(rows *sql.Rows) (*AuthNKeys, error)) {
|
||||
return sq.Select(
|
||||
AuthNKeyColumnID.identifier(),
|
||||
|
@ -8,6 +8,11 @@ import (
|
||||
"regexp"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
@ -470,3 +475,65 @@ func Test_AuthNKeyPrepares(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueries_GetAuthNKeyUser(t *testing.T) {
|
||||
expQuery := regexp.QuoteMeta(authNKeyUserQuery)
|
||||
cols := []string{"user_id", "resource_owner", "username", "access_token_type", "public_key"}
|
||||
pubkey := []byte(`-----BEGIN RSA PUBLIC KEY-----
|
||||
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA2ufAL1b72bIy1ar+Ws6b
|
||||
GohJJQFB7dfRapDqeqM8Ukp6CVdPzq/pOz1viAq50yzWZJryF+2wshFAKGF9A2/B
|
||||
2Yf9bJXPZ/KbkFrYT3NTvYDkvlaSTl9mMnzrU29s48F1PTWKfB+C3aMsOEG1BufV
|
||||
s63qF4nrEPjSbhljIco9FZq4XppIzhMQ0fDdA/+XygCJqvuaL0LibM1KrlUdnu71
|
||||
YekhSJjEPnvOisXIk4IXywoGIOwtjxkDvNItQvaMVldr4/kb6uvbgdWwq5EwBZXq
|
||||
low2kyJov38V4Uk2I8kuXpLcnrpw5Tio2ooiUE27b0vHZqBKOei9Uo88qCrn3EKx
|
||||
6QIDAQAB
|
||||
-----END RSA PUBLIC KEY-----`)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
mock sqlExpectation
|
||||
want *AuthNKeyUser
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "no rows",
|
||||
mock: mockQueryErr(expQuery, sql.ErrNoRows, "instanceID", "keyID", "userID"),
|
||||
wantErr: zerrors.ThrowNotFound(sql.ErrNoRows, "QUERY-Tha6f", "Errors.AuthNKey.NotFound"),
|
||||
},
|
||||
{
|
||||
name: "internal error",
|
||||
mock: mockQueryErr(expQuery, sql.ErrConnDone, "instanceID", "keyID", "userID"),
|
||||
wantErr: zerrors.ThrowInternal(sql.ErrConnDone, "QUERY-aen2A", "Errors.Internal"),
|
||||
},
|
||||
{
|
||||
name: "success",
|
||||
mock: mockQuery(expQuery, cols,
|
||||
[]driver.Value{"userID", "orgID", "username", domain.OIDCTokenTypeJWT, pubkey},
|
||||
"instanceID", "keyID", "userID",
|
||||
),
|
||||
want: &AuthNKeyUser{
|
||||
UserID: "userID",
|
||||
ResourceOwner: "orgID",
|
||||
Username: "username",
|
||||
TokenType: domain.OIDCTokenTypeJWT,
|
||||
PublicKey: pubkey,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
execMock(t, tt.mock, func(db *sql.DB) {
|
||||
q := &Queries{
|
||||
client: &database.DB{
|
||||
DB: db,
|
||||
Database: &prepareDB{},
|
||||
},
|
||||
}
|
||||
ctx := authz.NewMockContext("instanceID", "orgID", "userID")
|
||||
got, err := q.GetAuthNKeyUser(ctx, "keyID", "userID")
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
11
internal/query/authn_key_user.sql
Normal file
11
internal/query/authn_key_user.sql
Normal file
@ -0,0 +1,11 @@
|
||||
select u.id as user_id, u.resource_owner, u.username, m.access_token_type, k.public_key
|
||||
from projections.authn_keys2 k
|
||||
join projections.users13 u
|
||||
on k.instance_id = u.instance_id
|
||||
and k.identifier = u.id
|
||||
join projections.users13_machines m
|
||||
on u.instance_id = m.instance_id
|
||||
and u.id = m.user_id
|
||||
where k.instance_id = $1
|
||||
and k.id = $2
|
||||
and u.id = $3;
|
Loading…
Reference in New Issue
Block a user