From 66f91cdc4e9f46c8443880cf89c5062e2370050a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20M=C3=B6hlmann?= Date: Sun, 5 Nov 2023 13:18:17 +0200 Subject: [PATCH] client and project in single query --- internal/api/oidc/introspect.go | 63 +++++++++----- internal/api/oidc/key.go | 83 +++++++++++++------ internal/api/oidc/op.go | 3 +- internal/api/oidc/server.go | 3 + .../embed/introspection_client_by_id.sql | 24 ++++++ internal/query/introspection_client.go | 45 ++++++++++ 6 files changed, 173 insertions(+), 48 deletions(-) create mode 100644 internal/query/embed/introspection_client_by_id.sql create mode 100644 internal/query/introspection_client.go diff --git a/internal/api/oidc/introspect.go b/internal/api/oidc/introspect.go index 26688e529be..fa8ded77c93 100644 --- a/internal/api/oidc/introspect.go +++ b/internal/api/oidc/introspect.go @@ -2,6 +2,7 @@ package oidc import ( "context" + "database/sql" "errors" "slices" "strings" @@ -10,7 +11,10 @@ import ( "github.com/zitadel/oidc/v3/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/op" "github.com/zitadel/zitadel/internal/command" + "github.com/zitadel/zitadel/internal/crypto" errz "github.com/zitadel/zitadel/internal/errors" + "github.com/zitadel/zitadel/internal/query" + "github.com/zitadel/zitadel/internal/telemetry/tracing" ) func (s *Server) Introspect(ctx context.Context, r *op.Request[op.IntrospectionRequest]) (_ *op.Response, err error) { @@ -94,43 +98,56 @@ type instrospectionClientResult struct { } func (s *Server) instrospectionClientAuth(ctx context.Context, cc *op.ClientCredentials, rc chan<- *instrospectionClientResult) { - clientID := cc.ClientID + ctx, span := tracing.NewSpan(ctx) - if cc.ClientAssertion != "" { - verifier := op.NewJWTProfileVerifier(s.storage, op.IssuerFromContext(ctx), 1*time.Hour, time.Second) - profile, err := op.VerifyJWTAssertion(ctx, cc.ClientAssertion, verifier) + clientID, projectID, err := func() (string, string, error) { + client, err := s.clientFromCredentials(ctx, cc) if err != nil { - rc <- &instrospectionClientResult{ - err: oidc.ErrUnauthorizedClient().WithParent(err), - } - return + return "", "", err } - clientID = profile.Issuer - } else { - if err := s.storage.AuthorizeClientIDSecret(ctx, cc.ClientID, cc.ClientSecret); err != nil { - if err != nil { - rc <- &instrospectionClientResult{ - err: oidc.ErrUnauthorizedClient().WithParent(err), - } - return + + if cc.ClientAssertion != "" { + verifier := op.NewJWTProfileVerifierKeySet(keySetMap(client.PublicKeys), op.IssuerFromContext(ctx), time.Hour, time.Second) + if _, err := op.VerifyJWTAssertion(ctx, cc.ClientAssertion, verifier); err != nil { + return "", "", oidc.ErrUnauthorizedClient().WithParent(err) + } + } else { + if err := crypto.CompareHash(client.ClientSecret, []byte(cc.ClientSecret), s.hashAlg); err != nil { + return "", "", oidc.ErrUnauthorizedClient().WithParent(err) } } - } + return client.ClientID, client.ProjectID, nil + }() - // TODO: give clients their own aggregate, so we can skip this query - projectID, err := s.storage.query.ProjectIDFromClientID(ctx, clientID, false) - if err != nil { - rc <- &instrospectionClientResult{err: err} - return - } + span.EndWithError(err) rc <- &instrospectionClientResult{ clientID: clientID, projectID: projectID, + err: err, } } +// clientFromCredentials parses the client ID early, +// and makes a single query for the client for either auth methods. +func (s *Server) clientFromCredentials(ctx context.Context, cc *op.ClientCredentials) (client *query.IntrospectionClient, err error) { + if cc.ClientAssertion != "" { + claims := new(oidc.JWTTokenRequest) + if _, err := oidc.ParseToken(cc.ClientAssertion, claims); err != nil { + return nil, oidc.ErrUnauthorizedClient().WithParent(err) + } + client, err = s.storage.query.GetIntrospectionClientByID(ctx, claims.Issuer, true) + } else { + client, err = s.storage.query.GetIntrospectionClientByID(ctx, cc.ClientID, false) + } + if errors.Is(err, sql.ErrNoRows) { + return nil, oidc.ErrUnauthorizedClient().WithParent(err) + } + // any other error is regarded internal and should not be reported back to the client. + return client, err +} + type introspectionTokenResult struct { tokenID string userID string diff --git a/internal/api/oidc/key.go b/internal/api/oidc/key.go index bbc58b48263..9ded22c65df 100644 --- a/internal/api/oidc/key.go +++ b/internal/api/oidc/key.go @@ -20,14 +20,21 @@ import ( "github.com/zitadel/zitadel/internal/telemetry/tracing" ) -type keySet struct { +// keySetCache implements oidc.KeySet for Access Token verification. +// Public Keys are cached in a 2-dimentional map of Instance ID and Key ID. +// When a key is not present the queryKey function is called to obtain the key +// from the database. +type keySetCache struct { mtx sync.RWMutex instanceKeys map[string]map[string]query.PublicKey queryKey func(ctx context.Context, keyID string, current time.Time) (query.PublicKey, error) } -func newKeySet(background context.Context, purgeInterval time.Duration, queryKey func(ctx context.Context, keyID string, current time.Time) (query.PublicKey, error)) *keySet { - k := &keySet{ +// newKeySet initializes a keySetCache and starts a purging Go routine, +// which runs once every purgeInterval. +// When the passed context is done, the purge routine will terminate. +func newKeySet(background context.Context, purgeInterval time.Duration, queryKey func(ctx context.Context, keyID string, current time.Time) (query.PublicKey, error)) *keySetCache { + k := &keySetCache{ instanceKeys: make(map[string]map[string]query.PublicKey), queryKey: queryKey, } @@ -35,7 +42,7 @@ func newKeySet(background context.Context, purgeInterval time.Duration, queryKey return k } -func (v *keySet) purgeOnInterval(background context.Context, purgeInterval time.Duration) { +func (k *keySetCache) purgeOnInterval(background context.Context, purgeInterval time.Duration) { timer := time.NewTimer(purgeInterval) defer func() { if !timer.Stop() { @@ -43,50 +50,49 @@ func (v *keySet) purgeOnInterval(background context.Context, purgeInterval time. } }() -loop: for { select { case <-background.Done(): - break loop + return case <-timer.C: timer.Reset(purgeInterval) } // do the actual purging - v.mtx.Lock() - for instanceID, keys := range v.instanceKeys { + k.mtx.Lock() + for instanceID, keys := range k.instanceKeys { for keyID, key := range keys { if key.Expiry().Before(time.Now()) { delete(keys, keyID) } } if len(keys) == 0 { - delete(v.instanceKeys, instanceID) + delete(k.instanceKeys, instanceID) } } - v.mtx.Unlock() + k.mtx.Unlock() } } -func (v *keySet) setKey(instanceID, keyID string, key query.PublicKey) { - v.mtx.Lock() - defer v.mtx.Unlock() +func (k *keySetCache) setKey(instanceID, keyID string, key query.PublicKey) { + k.mtx.Lock() + defer k.mtx.Unlock() - if keys, ok := v.instanceKeys[instanceID]; ok { + if keys, ok := k.instanceKeys[instanceID]; ok { keys[keyID] = key return } - v.instanceKeys[instanceID] = map[string]query.PublicKey{keyID: key} + k.instanceKeys[instanceID] = map[string]query.PublicKey{keyID: key} } -func (v *keySet) getKey(ctx context.Context, keyID string, current time.Time) (*jose.JSONWebKey, error) { +func (k *keySetCache) getKey(ctx context.Context, keyID string, current time.Time) (*jose.JSONWebKey, error) { instanceID := authz.GetInstance(ctx).InstanceID() - v.mtx.RLock() - key, ok := v.instanceKeys[instanceID][keyID] - v.mtx.RUnlock() + k.mtx.RLock() + key, ok := k.instanceKeys[instanceID][keyID] + k.mtx.RUnlock() if ok { if key.Expiry().After(current) { @@ -95,24 +101,24 @@ func (v *keySet) getKey(ctx context.Context, keyID string, current time.Time) (* return nil, errors.ThrowInvalidArgument(nil, "OIDC-Zoh9E", "Errors.Key.ExpireBeforeNow") } - key, err := v.queryKey(ctx, keyID, current) + key, err := k.queryKey(ctx, keyID, current) if err != nil { return nil, err } - v.setKey(instanceID, keyID, key) + k.setKey(instanceID, keyID, key) return jsonWebkey(key), nil } // VerifySignature implements the oidc.KeySet interface. -func (v *keySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) { +func (k *keySetCache) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) { if len(jws.Signatures) != 1 { return nil, errors.ThrowInvalidArgument(nil, "OIDC-Gid9s", "Errors.Token.Invalid") } - key, err := v.getKey(ctx, jws.Signatures[0].Header.KeyID, time.Now()) + key, err := k.getKey(ctx, jws.Signatures[0].Header.KeyID, time.Now()) if err != nil { return nil, err } - return jws.Verify(&key) + return jws.Verify(key) } func jsonWebkey(key query.PublicKey) *jose.JSONWebKey { @@ -124,6 +130,35 @@ func jsonWebkey(key query.PublicKey) *jose.JSONWebKey { } } +// keySetMap is a mapping of key IDs to public key data. +type keySetMap map[string][]byte + +// getKey finds the keyID and parses the public key data +// into a JSONWebKey. +func (k keySetMap) getKey(keyID string) (*jose.JSONWebKey, error) { + pubKey, err := crypto.BytesToPublicKey(k[keyID]) + if err != nil { + return nil, err + } + return &jose.JSONWebKey{ + Key: pubKey, + KeyID: keyID, + Use: "sig", + }, nil +} + +// VerifySignature implements the oidc.KeySet interface. +func (k keySetMap) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) { + if len(jws.Signatures) != 1 { + return nil, errors.ThrowInvalidArgument(nil, "OIDC-Eeth6", "Errors.Token.Invalid") + } + key, err := k.getKey(jws.Signatures[0].Header.KeyID) + if err != nil { + return nil, err + } + return jws.Verify(key) +} + const ( locksTable = "projections.locks" signingKey = "signing_key" diff --git a/internal/api/oidc/op.go b/internal/api/oidc/op.go index 479d4e755fa..fd486cf355c 100644 --- a/internal/api/oidc/op.go +++ b/internal/api/oidc/op.go @@ -68,7 +68,7 @@ type OPStorage struct { command *command.Commands query *query.Queries eventstore *eventstore.Eventstore - keySet *keySet + keySet *keySetCache defaultLoginURL string defaultLoginURLV2 string defaultLogoutURLV2 string @@ -122,6 +122,7 @@ func NewServer( server := &Server{ storage: storage, LegacyServer: op.NewLegacyServer(provider, endpoints(config.CustomEndpoints)), + hashAlg: crypto.NewBCrypt(10), // as we are only verifying in oidc, the cost is already part of the hash string and the config here is irrelevant. } metricTypes := []metrics.MetricType{metrics.MetricTypeRequestCount, metrics.MetricTypeStatusCode, metrics.MetricTypeTotalCount} server.Handler = op.RegisterLegacyServer(server, op.WithHTTPMiddleware( diff --git a/internal/api/oidc/server.go b/internal/api/oidc/server.go index cd05e24e533..bdd0ae7df19 100644 --- a/internal/api/oidc/server.go +++ b/internal/api/oidc/server.go @@ -6,6 +6,7 @@ import ( "github.com/zitadel/oidc/v3/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/op" + "github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/telemetry/tracing" ) @@ -13,6 +14,8 @@ type Server struct { http.Handler storage *OPStorage *op.LegacyServer + + hashAlg crypto.HashAlgorithm } func endpoints(endpointConfig *EndpointConfig) op.Endpoints { diff --git a/internal/query/embed/introspection_client_by_id.sql b/internal/query/embed/introspection_client_by_id.sql new file mode 100644 index 00000000000..82a86ac87cc --- /dev/null +++ b/internal/query/embed/introspection_client_by_id.sql @@ -0,0 +1,24 @@ +with config as ( + select app_id, client_id, client_secret + from projections.apps5_api_configs + where instance_id = $1 + and client_id = $2 + union + select app_id, client_id, client_secret + from projections.apps5_oidc_configs + where instance_id = $1 + and client_id = $2 +), +keys as ( + select identifier as client_id, json_object_agg(id, public_key) as public_keys + from projections.authn_keys2 + where $3 = true -- when argument is false, don't waste time on trying to query for keys. + and instance_id = $1 + and identifier = $2 + and expiration > current_timestamp + group by identifier +) +select apps.project_id, config.client_secret, keys.public_keys from config +join projections.apps5 apps on apps.id = config.app_id +left join keys on keys.client_id = config.client_id +where apps.owner_removed = false; diff --git a/internal/query/introspection_client.go b/internal/query/introspection_client.go new file mode 100644 index 00000000000..461a21d6428 --- /dev/null +++ b/internal/query/introspection_client.go @@ -0,0 +1,45 @@ +package query + +import ( + "context" + "database/sql" + _ "embed" + + "github.com/jackc/pgtype" + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/crypto" + "github.com/zitadel/zitadel/internal/database" +) + +type IntrospectionClient struct { + ClientID string + ClientSecret *crypto.CryptoValue + ProjectID string + PublicKeys database.Map[[]byte] +} + +//go:embed embed/introspection_client_by_id.sql +var introspectionClientByIDQuery string + +func (q *Queries) GetIntrospectionClientByID(ctx context.Context, clientID string, getKeys bool) (_ *IntrospectionClient, err error) { + var ( + instanceID = authz.GetInstance(ctx).InstanceID() + client = new(IntrospectionClient) + ) + + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + var publicKeys pgtype.ByteaArray + if err := row.Scan(&client.ClientID, &client.ClientSecret, &client.ProjectID, &publicKeys); err != nil { + return err + } + return publicKeys.AssignTo(&client.PublicKeys) + }, + introspectionClientByIDQuery, + instanceID, clientID, getKeys, + ) + if err != nil { + return nil, err + } + + return client, nil +}