mirror of
https://github.com/zitadel/zitadel.git
synced 2025-04-07 07:44:32 +00:00

# Which Problems Are Solved [A recent performance enhancement](https://github.com/zitadel/zitadel/pull/9497) aimed at optimizing event store queries, specifically those involving multiple aggregate type filters, has successfully improved index utilization. While the query planner now correctly selects relevant indexes, it employs [bitmap index scans](https://www.postgresql.org/docs/current/indexes-bitmap-scans.html) to retrieve data. This approach, while beneficial in many scenarios, introduces a potential I/O bottleneck. The bitmap index scan first identifies the required database blocks and then utilizes a bitmap to access the corresponding rows from the table's heap. This subsequent "bitmap heap scan" can result in significant I/O overhead, particularly when queries return a substantial number of rows across numerous data pages. ## Impact: Under heavy load or with queries filtering for a wide range of events across multiple aggregate types, this increased I/O activity may lead to: - Increased query latency. - Elevated disk utilization. - Potential performance degradation of the event store and dependent systems. # How the Problems Are Solved To address this I/O bottleneck and further optimize query performance, the projection handler has been modified. Instead of employing multiple OR clauses for each aggregate type, the aggregate and event type filters are now combined using IN ARRAY filters. Technical Details: This change allows the PostgreSQL query planner to leverage [index-only scans](https://www.postgresql.org/docs/current/indexes-index-only-scans.html). By utilizing IN ARRAY filters, the database can efficiently retrieve the necessary data directly from the index, eliminating the need to access the table's heap. This results in: * Reduced I/O: Index-only scans significantly minimize disk I/O operations, as the database avoids reading data pages from the main table. 
* Improved Query Performance: By reducing I/O, query execution times are substantially improved, leading to lower latency. # Additional Changes - rollback of https://github.com/zitadel/zitadel/pull/9497 # Additional Information ## Query Plan of previous query ```sql SELECT created_at, event_type, "sequence", "position", payload, creator, "owner", instance_id, aggregate_type, aggregate_id, revision FROM eventstore.events2 WHERE instance_id = '<INSTANCE_ID>' AND ( ( instance_id = '<INSTANCE_ID>' AND "position" > <POSITION> AND aggregate_type = 'project' AND event_type = ANY(ARRAY[ 'project.application.added' ,'project.application.changed' ,'project.application.deactivated' ,'project.application.reactivated' ,'project.application.removed' ,'project.removed' ,'project.application.config.api.added' ,'project.application.config.api.changed' ,'project.application.config.api.secret.changed' ,'project.application.config.api.secret.updated' ,'project.application.config.oidc.added' ,'project.application.config.oidc.changed' ,'project.application.config.oidc.secret.changed' ,'project.application.config.oidc.secret.updated' ,'project.application.config.saml.added' ,'project.application.config.saml.changed' ]) ) OR ( instance_id = '<INSTANCE_ID>' AND "position" > <POSITION> AND aggregate_type = 'org' AND event_type = 'org.removed' ) OR ( instance_id = '<INSTANCE_ID>' AND "position" > <POSITION> AND aggregate_type = 'instance' AND event_type = 'instance.removed' ) ) AND "position" > 1741600905.3495 AND "position" < ( SELECT COALESCE(EXTRACT(EPOCH FROM min(xact_start)), EXTRACT(EPOCH FROM now())) FROM pg_stat_activity WHERE datname = current_database() AND application_name = ANY(ARRAY['zitadel_es_pusher_', 'zitadel_es_pusher', 'zitadel_es_pusher_<INSTANCE_ID>']) AND state <> 'idle' ) ORDER BY "position", in_tx_order LIMIT 200 OFFSET 1; ``` ``` Limit (cost=120.08..120.09 rows=7 width=361) (actual time=2.167..2.172 rows=0 loops=1) Output: events2.created_at, events2.event_type, 
events2.sequence, events2."position", events2.payload, events2.creator, events2.owner, events2.instance_id, events2.aggregate_type, events2.aggregate_id, events2.revision, events2.in_tx_order InitPlan 1 -> Aggregate (cost=2.74..2.76 rows=1 width=32) (actual time=1.813..1.815 rows=1 loops=1) Output: COALESCE(EXTRACT(epoch FROM min(s.xact_start)), EXTRACT(epoch FROM now())) -> Nested Loop (cost=0.00..2.74 rows=1 width=8) (actual time=1.803..1.805 rows=0 loops=1) Output: s.xact_start Join Filter: (d.oid = s.datid) -> Seq Scan on pg_catalog.pg_database d (cost=0.00..1.07 rows=1 width=4) (actual time=0.016..0.021 rows=1 loops=1) Output: d.oid, d.datname, d.datdba, d.encoding, d.datlocprovider, d.datistemplate, d.datallowconn, d.dathasloginevt, d.datconnlimit, d.datfrozenxid, d.datminmxid, d.dattablespace, d.datcollate, d.datctype, d.datlocale, d.daticurules, d.datcollversion, d.datacl Filter: (d.datname = current_database()) Rows Removed by Filter: 4 -> Function Scan on pg_catalog.pg_stat_get_activity s (cost=0.00..1.63 rows=3 width=16) (actual time=1.781..1.781 rows=0 loops=1) Output: s.datid, s.pid, s.usesysid, s.application_name, s.state, s.query, s.wait_event_type, s.wait_event, s.xact_start, s.query_start, s.backend_start, s.state_change, s.client_addr, s.client_hostname, s.client_port, s.backend_xid, s.backend_xmin, s.backend_type, s.ssl, s.sslversion, s.sslcipher, s.sslbits, s.ssl_client_dn, s.ssl_client_serial, s.ssl_issuer_dn, s.gss_auth, s.gss_princ, s.gss_enc, s.gss_delegation, s.leader_pid, s.query_id Function Call: pg_stat_get_activity(NULL::integer) Filter: ((s.state <> 'idle'::text) AND (s.application_name = ANY ('{zitadel_es_pusher_,zitadel_es_pusher,zitadel_es_pusher_<INSTANCE_ID>}'::text[]))) Rows Removed by Filter: 49 -> Sort (cost=117.31..117.33 rows=8 width=361) (actual time=2.167..2.168 rows=0 loops=1) Output: events2.created_at, events2.event_type, events2.sequence, events2."position", events2.payload, events2.creator, events2.owner, 
events2.instance_id, events2.aggregate_type, events2.aggregate_id, events2.revision, events2.in_tx_order Sort Key: events2."position", events2.in_tx_order Sort Method: quicksort Memory: 25kB -> Bitmap Heap Scan on eventstore.events2 (cost=84.92..117.19 rows=8 width=361) (actual time=2.088..2.089 rows=0 loops=1) Output: events2.created_at, events2.event_type, events2.sequence, events2."position", events2.payload, events2.creator, events2.owner, events2.instance_id, events2.aggregate_type, events2.aggregate_id, events2.revision, events2.in_tx_order Recheck Cond: (((events2.instance_id = '<INSTANCE_ID>'::text) AND (events2.aggregate_type = 'project'::text) AND (events2.event_type = ANY ('{project.application.added,project.application.changed,project.application.deactivated,project.application.reactivated,project.application.removed,project.removed,project.application.config.api.added,project.application.config.api.changed,project.application.config.api.secret.changed,project.application.config.api.secret.updated,project.application.config.oidc.added,project.application.config.oidc.changed,project.application.config.oidc.secret.changed,project.application.config.oidc.secret.updated,project.application.config.saml.added,project.application.config.saml.changed}'::text[])) AND (events2."position" > <POSITION>) AND (events2."position" > 1741600905.3495) AND (events2."position" < (InitPlan 1).col1)) OR ((events2.instance_id = '<INSTANCE_ID>'::text) AND (events2.aggregate_type = 'org'::text) AND (events2.event_type = 'org.removed'::text) AND (events2."position" > <POSITION>) AND (events2."position" > 1741600905.3495) AND (events2."position" < (InitPlan 1).col1)) OR ((events2.instance_id = '<INSTANCE_ID>'::text) AND (events2.aggregate_type = 'instance'::text) AND (events2.event_type = 'instance.removed'::text) AND (events2."position" > <POSITION>) AND (events2."position" > 1741600905.3495) AND (events2."position" < (InitPlan 1).col1))) -> BitmapOr (cost=84.88..84.88 rows=8 
width=0) (actual time=2.080..2.081 rows=0 loops=1) -> Bitmap Index Scan on es_projection (cost=0.00..75.44 rows=8 width=0) (actual time=2.016..2.017 rows=0 loops=1) Index Cond: ((events2.instance_id = '<INSTANCE_ID>'::text) AND (events2.aggregate_type = 'project'::text) AND (events2.event_type = ANY ('{project.application.added,project.application.changed,project.application.deactivated,project.application.reactivated,project.application.removed,project.removed,project.application.config.api.added,project.application.config.api.changed,project.application.config.api.secret.changed,project.application.config.api.secret.updated,project.application.config.oidc.added,project.application.config.oidc.changed,project.application.config.oidc.secret.changed,project.application.config.oidc.secret.updated,project.application.config.saml.added,project.application.config.saml.changed}'::text[])) AND (events2."position" > <POSITION>) AND (events2."position" > 1741600905.3495) AND (events2."position" < (InitPlan 1).col1)) -> Bitmap Index Scan on es_projection (cost=0.00..4.71 rows=1 width=0) (actual time=0.016..0.016 rows=0 loops=1) Index Cond: ((events2.instance_id = '<INSTANCE_ID>'::text) AND (events2.aggregate_type = 'org'::text) AND (events2.event_type = 'org.removed'::text) AND (events2."position" > <POSITION>) AND (events2."position" > 1741600905.3495) AND (events2."position" < (InitPlan 1).col1)) -> Bitmap Index Scan on es_projection (cost=0.00..4.71 rows=1 width=0) (actual time=0.045..0.045 rows=0 loops=1) Index Cond: ((events2.instance_id = '<INSTANCE_ID>'::text) AND (events2.aggregate_type = 'instance'::text) AND (events2.event_type = 'instance.removed'::text) AND (events2."position" > <POSITION>) AND (events2."position" > 1741600905.3495) AND (events2."position" < (InitPlan 1).col1)) Query Identifier: 3194938266011254479 Planning Time: 1.295 ms Execution Time: 2.832 ms ``` ## Query Plan of new query ```sql SELECT created_at, event_type, "sequence", "position", payload, 
creator, "owner", instance_id, aggregate_type, aggregate_id, revision FROM eventstore.events2 WHERE instance_id = '<INSTANCE_ID>' AND "position" > <POSITION> AND aggregate_type = ANY(ARRAY['project', 'instance', 'org']) AND event_type = ANY(ARRAY[ 'project.application.added' ,'project.application.changed' ,'project.application.deactivated' ,'project.application.reactivated' ,'project.application.removed' ,'project.removed' ,'project.application.config.api.added' ,'project.application.config.api.changed' ,'project.application.config.api.secret.changed' ,'project.application.config.api.secret.updated' ,'project.application.config.oidc.added' ,'project.application.config.oidc.changed' ,'project.application.config.oidc.secret.changed' ,'project.application.config.oidc.secret.updated' ,'project.application.config.saml.added' ,'project.application.config.saml.changed' ,'org.removed' ,'instance.removed' ]) AND "position" < ( SELECT COALESCE(EXTRACT(EPOCH FROM min(xact_start)), EXTRACT(EPOCH FROM now())) FROM pg_stat_activity WHERE datname = current_database() AND application_name = ANY(ARRAY['zitadel_es_pusher_', 'zitadel_es_pusher', 'zitadel_es_pusher_<INSTANCE_ID>']) AND state <> 'idle' ) ORDER BY "position", in_tx_order LIMIT 200 OFFSET 1; ``` ``` Limit (cost=293.34..293.36 rows=8 width=361) (actual time=4.686..4.689 rows=0 loops=1) Output: events2.created_at, events2.event_type, events2.sequence, events2."position", events2.payload, events2.creator, events2.owner, events2.instance_id, events2.aggregate_type, events2.aggregate_id, events2.revision, events2.in_tx_order InitPlan 1 -> Aggregate (cost=2.74..2.76 rows=1 width=32) (actual time=1.717..1.719 rows=1 loops=1) Output: COALESCE(EXTRACT(epoch FROM min(s.xact_start)), EXTRACT(epoch FROM now())) -> Nested Loop (cost=0.00..2.74 rows=1 width=8) (actual time=1.658..1.659 rows=0 loops=1) Output: s.xact_start Join Filter: (d.oid = s.datid) -> Seq Scan on pg_catalog.pg_database d (cost=0.00..1.07 rows=1 width=4) (actual 
time=0.026..0.028 rows=1 loops=1) Output: d.oid, d.datname, d.datdba, d.encoding, d.datlocprovider, d.datistemplate, d.datallowconn, d.dathasloginevt, d.datconnlimit, d.datfrozenxid, d.datminmxid, d.dattablespace, d.datcollate, d.datctype, d.datlocale, d.daticurules, d.datcollversion, d.datacl Filter: (d.datname = current_database()) Rows Removed by Filter: 4 -> Function Scan on pg_catalog.pg_stat_get_activity s (cost=0.00..1.63 rows=3 width=16) (actual time=1.628..1.628 rows=0 loops=1) Output: s.datid, s.pid, s.usesysid, s.application_name, s.state, s.query, s.wait_event_type, s.wait_event, s.xact_start, s.query_start, s.backend_start, s.state_change, s.client_addr, s.client_hostname, s.client_port, s.backend_xid, s.backend_xmin, s.backend_type, s.ssl, s.sslversion, s.sslcipher, s.sslbits, s.ssl_client_dn, s.ssl_client_serial, s.ssl_issuer_dn, s.gss_auth, s.gss_princ, s.gss_enc, s.gss_delegation, s.leader_pid, s.query_id Function Call: pg_stat_get_activity(NULL::integer) Filter: ((s.state <> 'idle'::text) AND (s.application_name = ANY ('{zitadel_es_pusher_,zitadel_es_pusher,zitadel_es_pusher_<INSTANCE_ID>}'::text[]))) Rows Removed by Filter: 42 -> Sort (cost=290.58..290.60 rows=9 width=361) (actual time=4.685..4.685 rows=0 loops=1) Output: events2.created_at, events2.event_type, events2.sequence, events2."position", events2.payload, events2.creator, events2.owner, events2.instance_id, events2.aggregate_type, events2.aggregate_id, events2.revision, events2.in_tx_order Sort Key: events2."position", events2.in_tx_order Sort Method: quicksort Memory: 25kB -> Index Scan using es_projection on eventstore.events2 (cost=0.70..290.43 rows=9 width=361) (actual time=4.616..4.617 rows=0 loops=1) Output: events2.created_at, events2.event_type, events2.sequence, events2."position", events2.payload, events2.creator, events2.owner, events2.instance_id, events2.aggregate_type, events2.aggregate_id, events2.revision, events2.in_tx_order Index Cond: ((events2.instance_id = 
'<INSTANCE_ID>'::text) AND (events2.aggregate_type = ANY ('{project,instance,org}'::text[])) AND (events2.event_type = ANY ('{project.application.added,project.application.changed,project.application.deactivated,project.application.reactivated,project.application.removed,project.removed,project.application.config.api.added,project.application.config.api.changed,project.application.config.api.secret.changed,project.application.config.api.secret.updated,project.application.config.oidc.added,project.application.config.oidc.changed,project.application.config.oidc.secret.changed,project.application.config.oidc.secret.updated,project.application.config.saml.added,project.application.config.saml.changed,org.removed,instance.removed}'::text[])) AND (events2."position" > <POSITION>) AND (events2."position" < (InitPlan 1).col1)) Query Identifier: -8254550537132386499 Planning Time: 2.864 ms Execution Time: 5.414 ms ``` (cherry picked from commit e36f402e093f53b9a8ef614da2e3c77c65cb45f5)
517 lines
14 KiB
Go
517 lines
14 KiB
Go
package oidc
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"slices"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/go-jose/go-jose/v4"
|
|
"github.com/jonboulle/clockwork"
|
|
"github.com/muhlemmer/gu"
|
|
"github.com/zitadel/logging"
|
|
"github.com/zitadel/oidc/v3/pkg/op"
|
|
|
|
"github.com/zitadel/zitadel/internal/api/authz"
|
|
http_util "github.com/zitadel/zitadel/internal/api/http"
|
|
"github.com/zitadel/zitadel/internal/crypto"
|
|
"github.com/zitadel/zitadel/internal/eventstore"
|
|
"github.com/zitadel/zitadel/internal/query"
|
|
"github.com/zitadel/zitadel/internal/repository/instance"
|
|
"github.com/zitadel/zitadel/internal/repository/keypair"
|
|
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
|
"github.com/zitadel/zitadel/internal/zerrors"
|
|
)
|
|
|
|
var supportedWebKeyAlgs = []string{
|
|
string(jose.EdDSA),
|
|
string(jose.RS256),
|
|
string(jose.RS384),
|
|
string(jose.RS512),
|
|
string(jose.ES256),
|
|
string(jose.ES384),
|
|
string(jose.ES512),
|
|
}
|
|
|
|
func supportedSigningAlgs(ctx context.Context) []string {
|
|
if authz.GetFeatures(ctx).WebKey {
|
|
return supportedWebKeyAlgs
|
|
}
|
|
return []string{string(jose.RS256)}
|
|
}
|
|
|
|
type cachedPublicKey struct {
|
|
lastUse atomic.Int64 // unix micro time.
|
|
expiry *time.Time // expiry may be nil if the key does not expire.
|
|
webKey *jose.JSONWebKey
|
|
}
|
|
|
|
func newCachedPublicKey(key *jose.JSONWebKey, expiry *time.Time, now time.Time) *cachedPublicKey {
|
|
cachedKey := &cachedPublicKey{
|
|
expiry: expiry,
|
|
webKey: key,
|
|
}
|
|
cachedKey.setLastUse(now)
|
|
return cachedKey
|
|
}
|
|
|
|
func (c *cachedPublicKey) setLastUse(now time.Time) {
|
|
c.lastUse.Store(now.UnixMicro())
|
|
}
|
|
|
|
func (c *cachedPublicKey) getLastUse() time.Time {
|
|
return time.UnixMicro(c.lastUse.Load())
|
|
}
|
|
|
|
func (c *cachedPublicKey) expired(now time.Time, validity time.Duration) bool {
|
|
return c.getLastUse().Add(validity).Before(now)
|
|
}
|
|
|
|
// publicKeyCache caches public keys in a 2-dimensional map of Instance ID and Key ID.
|
|
// When a key is not present the queryKey function is called to obtain the key
|
|
// from the database.
|
|
type publicKeyCache struct {
|
|
mtx sync.RWMutex
|
|
instanceKeys map[string]map[string]*cachedPublicKey
|
|
|
|
// queryKey returns a public web key.
|
|
// If the key does not have expiry, Time may be nil.
|
|
queryKey func(ctx context.Context, keyID string) (*jose.JSONWebKey, *time.Time, error)
|
|
clock clockwork.Clock
|
|
}
|
|
|
|
// newPublicKeyCache initializes a keySetCache starts a purging Go routine.
|
|
// The purge routine deletes all public keys that are older than maxAge.
|
|
// When the passed context is done, the purge routine will terminate.
|
|
func newPublicKeyCache(background context.Context, maxAge time.Duration, queryKey func(ctx context.Context, keyID string) (*jose.JSONWebKey, *time.Time, error)) *publicKeyCache {
|
|
k := &publicKeyCache{
|
|
instanceKeys: make(map[string]map[string]*cachedPublicKey),
|
|
queryKey: queryKey,
|
|
clock: clockwork.FromContext(background), // defaults to real clock
|
|
}
|
|
go k.purgeOnInterval(background, k.clock.NewTicker(maxAge/5), maxAge)
|
|
return k
|
|
}
|
|
|
|
func (k *publicKeyCache) purgeOnInterval(background context.Context, ticker clockwork.Ticker, maxAge time.Duration) {
|
|
defer ticker.Stop()
|
|
for {
|
|
select {
|
|
case <-background.Done():
|
|
return
|
|
case <-ticker.Chan():
|
|
}
|
|
|
|
// do the actual purging
|
|
k.mtx.Lock()
|
|
for instanceID, keys := range k.instanceKeys {
|
|
for keyID, key := range keys {
|
|
if key.expired(k.clock.Now(), maxAge) {
|
|
delete(keys, keyID)
|
|
}
|
|
}
|
|
if len(keys) == 0 {
|
|
delete(k.instanceKeys, instanceID)
|
|
}
|
|
}
|
|
k.mtx.Unlock()
|
|
}
|
|
}
|
|
|
|
func (k *publicKeyCache) setKey(instanceID, keyID string, cachedKey *cachedPublicKey) {
|
|
k.mtx.Lock()
|
|
defer k.mtx.Unlock()
|
|
|
|
if keys, ok := k.instanceKeys[instanceID]; ok {
|
|
keys[keyID] = cachedKey
|
|
return
|
|
}
|
|
k.instanceKeys[instanceID] = map[string]*cachedPublicKey{keyID: cachedKey}
|
|
}
|
|
|
|
func (k *publicKeyCache) getKey(ctx context.Context, keyID string) (_ *cachedPublicKey, err error) {
|
|
ctx, span := tracing.NewSpan(ctx)
|
|
defer func() { span.EndWithError(err) }()
|
|
|
|
instanceID := authz.GetInstance(ctx).InstanceID()
|
|
|
|
k.mtx.RLock()
|
|
key, ok := k.instanceKeys[instanceID][keyID]
|
|
k.mtx.RUnlock()
|
|
|
|
if ok {
|
|
key.setLastUse(k.clock.Now())
|
|
} else {
|
|
newKey, expiry, err := k.queryKey(ctx, keyID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
key = newCachedPublicKey(newKey, expiry, k.clock.Now())
|
|
k.setKey(instanceID, keyID, key)
|
|
}
|
|
|
|
return key, nil
|
|
}
|
|
|
|
func (k *publicKeyCache) verifySignature(ctx context.Context, jws *jose.JSONWebSignature, checkKeyExpiry bool) (_ []byte, err error) {
|
|
ctx, span := tracing.NewSpan(ctx)
|
|
defer func() {
|
|
err = oidcError(err)
|
|
span.EndWithError(err)
|
|
}()
|
|
|
|
if len(jws.Signatures) != 1 {
|
|
return nil, zerrors.ThrowInvalidArgument(nil, "OIDC-Gid9s", "Errors.Token.Invalid")
|
|
}
|
|
key, err := k.getKey(ctx, jws.Signatures[0].Header.KeyID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if checkKeyExpiry && key.expiry != nil && key.expiry.Before(k.clock.Now()) {
|
|
return nil, zerrors.ThrowInvalidArgument(err, "QUERY-ciF4k", "Errors.Key.ExpireBeforeNow")
|
|
}
|
|
return jws.Verify(key.webKey)
|
|
}
|
|
|
|
type oidcKeySet struct {
|
|
*publicKeyCache
|
|
|
|
keyExpiryCheck bool
|
|
}
|
|
|
|
// newOidcKeySet returns an oidc.KeySet implementation around the passed cache.
|
|
// It is advised to reuse the same cache if different key set configurations are required.
|
|
func newOidcKeySet(cache *publicKeyCache, opts ...keySetOption) *oidcKeySet {
|
|
k := &oidcKeySet{
|
|
publicKeyCache: cache,
|
|
}
|
|
for _, opt := range opts {
|
|
opt(k)
|
|
}
|
|
return k
|
|
}
|
|
|
|
// VerifySignature implements the oidc.KeySet interface.
|
|
func (k *oidcKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (_ []byte, err error) {
|
|
return k.verifySignature(ctx, jws, k.keyExpiryCheck)
|
|
}
|
|
|
|
type keySetOption func(*oidcKeySet)
|
|
|
|
// withKeyExpiryCheck forces VerifySignature to check the expiry of the public key.
|
|
// Note that public key expiry is not part of the standard,
|
|
// but is currently established behavior of zitadel.
|
|
// We might want to remove this check in the future.
|
|
func withKeyExpiryCheck(check bool) keySetOption {
|
|
return func(k *oidcKeySet) {
|
|
k.keyExpiryCheck = check
|
|
}
|
|
}
|
|
|
|
func jsonWebkey(key query.PublicKey) *jose.JSONWebKey {
|
|
return &jose.JSONWebKey{
|
|
KeyID: key.ID(),
|
|
Algorithm: key.Algorithm(),
|
|
Use: key.Use().String(),
|
|
Key: key.Key(),
|
|
}
|
|
}
|
|
|
|
// keySetMap is a mapping of key IDs to public key data.
|
|
type keySetMap map[string][]byte
|
|
|
|
// getKey finds the keyID and parses the public key data
|
|
// into a JSONWebKey.
|
|
func (k keySetMap) getKey(keyID string) (*jose.JSONWebKey, error) {
|
|
pubKey, err := crypto.BytesToPublicKey(k[keyID])
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &jose.JSONWebKey{
|
|
Key: pubKey,
|
|
KeyID: keyID,
|
|
Use: crypto.KeyUsageSigning.String(),
|
|
}, nil
|
|
}
|
|
|
|
// VerifySignature implements the oidc.KeySet interface.
|
|
func (k keySetMap) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
|
|
if len(jws.Signatures) != 1 {
|
|
return nil, zerrors.ThrowInvalidArgument(nil, "OIDC-Eeth6", "Errors.Token.Invalid")
|
|
}
|
|
key, err := k.getKey(jws.Signatures[0].Header.KeyID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return jws.Verify(key)
|
|
}
|
|
|
|
const (
	// NOTE(review): locksTable and signingKey are not referenced in this part
	// of the file; presumably they name the lock used for key generation — verify.
	locksTable = "projections.locks"
	signingKey = "signing_key"
	// oidcUser is the user ID placed on the context by setOIDCCtx.
	oidcUser = "OIDC"

	// retry settings: number of attempts and the back-off pause.
	retryCount   = 3
	retryBackoff = 500 * time.Millisecond
	// lockDuration is the duration requested for the key generation lock (7.5s).
	lockDuration = 5 * retryCount * retryBackoff

	// gracefulPeriod is added to the current time when querying the active
	// private signing key.
	gracefulPeriod = 10 * time.Minute
)
|
|
|
|
// SigningKey wraps the query.PrivateKey to implement the op.SigningKey interface
|
|
type SigningKey struct {
|
|
algorithm jose.SignatureAlgorithm
|
|
id string
|
|
key interface{}
|
|
}
|
|
|
|
func (s *SigningKey) SignatureAlgorithm() jose.SignatureAlgorithm {
|
|
return s.algorithm
|
|
}
|
|
|
|
func (s *SigningKey) Key() interface{} {
|
|
return s.key
|
|
}
|
|
|
|
func (s *SigningKey) ID() string {
|
|
return s.id
|
|
}
|
|
|
|
// PublicKey wraps the query.PublicKey to implement the op.Key interface
|
|
type PublicKey struct {
|
|
key query.PublicKey
|
|
}
|
|
|
|
func (s *PublicKey) Algorithm() jose.SignatureAlgorithm {
|
|
return jose.SignatureAlgorithm(s.key.Algorithm())
|
|
}
|
|
|
|
func (s *PublicKey) Use() string {
|
|
return s.key.Use().String()
|
|
}
|
|
|
|
func (s *PublicKey) Key() interface{} {
|
|
return s.key.Key()
|
|
}
|
|
|
|
func (s *PublicKey) ID() string {
|
|
return s.key.ID()
|
|
}
|
|
|
|
// KeySet implements the op.Storage interface
|
|
func (o *OPStorage) KeySet(ctx context.Context) (keys []op.Key, err error) {
|
|
ctx, span := tracing.NewSpan(ctx)
|
|
defer func() { span.EndWithError(err) }()
|
|
err = retry(func() error {
|
|
publicKeys, err := o.query.ActivePublicKeys(ctx, time.Now())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
keys = make([]op.Key, len(publicKeys.Keys))
|
|
for i, key := range publicKeys.Keys {
|
|
keys[i] = &PublicKey{key}
|
|
}
|
|
return nil
|
|
})
|
|
return keys, err
|
|
}
|
|
|
|
// SignatureAlgorithms implements the op.Storage interface
|
|
func (o *OPStorage) SignatureAlgorithms(ctx context.Context) ([]jose.SignatureAlgorithm, error) {
|
|
key, err := o.SigningKey(ctx)
|
|
if err != nil {
|
|
logging.WithError(err).Warn("unable to fetch signing key")
|
|
return nil, err
|
|
}
|
|
return []jose.SignatureAlgorithm{key.SignatureAlgorithm()}, nil
|
|
}
|
|
|
|
// SigningKey implements the op.Storage interface
|
|
func (o *OPStorage) SigningKey(ctx context.Context) (key op.SigningKey, err error) {
|
|
err = retry(func() error {
|
|
key, err = o.getSigningKey(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if key == nil {
|
|
return zerrors.ThrowNotFound(nil, "OIDC-ve4Qu", "Errors.Internal")
|
|
}
|
|
return nil
|
|
})
|
|
return key, err
|
|
}
|
|
|
|
func (o *OPStorage) getSigningKey(ctx context.Context) (op.SigningKey, error) {
|
|
keys, err := o.query.ActivePrivateSigningKey(ctx, time.Now().Add(gracefulPeriod))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(keys.Keys) > 0 {
|
|
return PrivateKeyToSigningKey(SelectSigningKey(keys.Keys), o.encAlg)
|
|
}
|
|
var position float64
|
|
if keys.State != nil {
|
|
position = keys.State.Position
|
|
}
|
|
return nil, o.refreshSigningKey(ctx, position)
|
|
}
|
|
|
|
func (o *OPStorage) refreshSigningKey(ctx context.Context, position float64) error {
|
|
ok, err := o.ensureIsLatestKey(ctx, position)
|
|
if err != nil || !ok {
|
|
return zerrors.ThrowInternal(err, "OIDC-ASfh3", "cannot ensure that projection is up to date")
|
|
}
|
|
err = o.lockAndGenerateSigningKeyPair(ctx)
|
|
if err != nil {
|
|
return zerrors.ThrowInternal(err, "OIDC-ADh31", "could not create signing key")
|
|
}
|
|
return zerrors.ThrowInternal(nil, "OIDC-Df1bh", "")
|
|
}
|
|
|
|
func (o *OPStorage) ensureIsLatestKey(ctx context.Context, position float64) (bool, error) {
|
|
maxSequence, err := o.getMaxKeySequence(ctx)
|
|
if err != nil {
|
|
return false, fmt.Errorf("error retrieving new events: %w", err)
|
|
}
|
|
return position >= maxSequence, nil
|
|
}
|
|
|
|
func PrivateKeyToSigningKey(key query.PrivateKey, algorithm crypto.EncryptionAlgorithm) (_ op.SigningKey, err error) {
|
|
keyData, err := crypto.Decrypt(key.Key(), algorithm)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
privateKey, err := crypto.BytesToPrivateKey(keyData)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &SigningKey{
|
|
algorithm: jose.SignatureAlgorithm(key.Algorithm()),
|
|
key: privateKey,
|
|
id: key.ID(),
|
|
}, nil
|
|
}
|
|
|
|
func (o *OPStorage) lockAndGenerateSigningKeyPair(ctx context.Context) error {
|
|
logging.Info("lock and generate signing key pair")
|
|
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
|
|
errs := o.locker.Lock(ctx, lockDuration, authz.GetInstance(ctx).InstanceID())
|
|
err, ok := <-errs
|
|
if err != nil || !ok {
|
|
if zerrors.IsErrorAlreadyExists(err) {
|
|
return nil
|
|
}
|
|
logging.OnError(err).Debug("initial lock failed")
|
|
return err
|
|
}
|
|
|
|
return o.command.GenerateSigningKeyPair(setOIDCCtx(ctx), "RS256")
|
|
}
|
|
|
|
func (o *OPStorage) getMaxKeySequence(ctx context.Context) (float64, error) {
|
|
return o.eventstore.LatestSequence(ctx,
|
|
eventstore.NewSearchQueryBuilder(eventstore.ColumnsMaxSequence).
|
|
ResourceOwner(authz.GetInstance(ctx).InstanceID()).
|
|
AwaitOpenTransactions().
|
|
AllowTimeTravel().
|
|
AddQuery().
|
|
AggregateTypes(
|
|
keypair.AggregateType,
|
|
instance.AggregateType,
|
|
).
|
|
EventTypes(
|
|
keypair.AddedEventType,
|
|
instance.InstanceRemovedEventType,
|
|
).
|
|
Builder(),
|
|
)
|
|
}
|
|
|
|
func SelectSigningKey(keys []query.PrivateKey) query.PrivateKey {
|
|
return keys[len(keys)-1]
|
|
}
|
|
|
|
func setOIDCCtx(ctx context.Context) context.Context {
|
|
return authz.SetCtxData(ctx, authz.CtxData{UserID: oidcUser, OrgID: authz.GetInstance(ctx).InstanceID()})
|
|
}
|
|
|
|
func retry(retryable func() error) (err error) {
|
|
for i := 0; i < retryCount; i++ {
|
|
err = retryable()
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
time.Sleep(retryBackoff)
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (s *Server) Keys(ctx context.Context, r *op.Request[struct{}]) (_ *op.Response, err error) {
|
|
ctx, span := tracing.NewSpan(ctx)
|
|
defer func() { span.EndWithError(err) }()
|
|
|
|
if !authz.GetFeatures(ctx).WebKey {
|
|
return s.LegacyServer.Keys(ctx, r)
|
|
}
|
|
|
|
keyset, err := s.query.GetWebKeySet(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Return legacy keys, so we do not invalidate all tokens
|
|
// once the feature flag is enabled.
|
|
legacyKeys, err := s.query.ActivePublicKeys(ctx, time.Now())
|
|
logging.OnError(err).Error("oidc server: active public keys (legacy)")
|
|
appendPublicKeysToWebKeySet(keyset, legacyKeys)
|
|
|
|
resp := op.NewResponse(keyset)
|
|
if s.jwksCacheControlMaxAge != 0 {
|
|
resp.Header.Set(http_util.CacheControl,
|
|
fmt.Sprintf("max-age=%d, must-revalidate", int(s.jwksCacheControlMaxAge/time.Second)),
|
|
)
|
|
}
|
|
|
|
return resp, nil
|
|
}
|
|
|
|
func appendPublicKeysToWebKeySet(keyset *jose.JSONWebKeySet, pubkeys *query.PublicKeys) {
|
|
if pubkeys == nil || len(pubkeys.Keys) == 0 {
|
|
return
|
|
}
|
|
keyset.Keys = slices.Grow(keyset.Keys, len(pubkeys.Keys))
|
|
|
|
for _, key := range pubkeys.Keys {
|
|
keyset.Keys = append(keyset.Keys, jose.JSONWebKey{
|
|
Key: key.Key(),
|
|
KeyID: key.ID(),
|
|
Algorithm: key.Algorithm(),
|
|
Use: key.Use().String(),
|
|
})
|
|
}
|
|
}
|
|
|
|
func queryKeyFunc(q *query.Queries) func(ctx context.Context, keyID string) (*jose.JSONWebKey, *time.Time, error) {
|
|
return func(ctx context.Context, keyID string) (*jose.JSONWebKey, *time.Time, error) {
|
|
if authz.GetFeatures(ctx).WebKey {
|
|
webKey, err := q.GetPublicWebKeyByID(ctx, keyID)
|
|
if err == nil {
|
|
return webKey, nil, nil
|
|
}
|
|
if !zerrors.IsNotFound(err) {
|
|
return nil, nil, err
|
|
}
|
|
}
|
|
|
|
pubKey, err := q.GetPublicKeyByID(ctx, keyID)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
return jsonWebkey(pubKey), gu.Ptr(pubKey.Expiry()), nil
|
|
}
|
|
}
|