mirror of
https://github.com/zitadel/zitadel.git
synced 2024-12-14 20:08:02 +00:00
fd0c15dd4f
# Which Problems Are Solved

Use web keys, managed by the `resources/v3alpha/web_keys` API, for OIDC token signing and verification, as well as serving the public web keys on the jwks / keys endpoint.
The response header on the keys endpoint now allows caching of the response. This is now "safe" to do, since keys can be created ahead of time and caches have sufficient time to pick up the change before keys get enabled.

# How the Problems Are Solved

- The web key format is used in the `getSignerOnce` function in the `api/oidc` package.
- The public key cache is changed to get and store web keys.
- The jwks / keys endpoint returns the combined set of valid "legacy" public keys and all available web keys.
- Cache-Control max-age defaults to 5 minutes and is configured in `defaults.yaml`.

When the web keys feature is enabled, fallback mechanisms are in place to obtain and convert "legacy" `query.PublicKey` to web keys when needed. This allows transitioning to the feature without invalidating existing tokens. A small performance overhead may be noticed on the keys endpoint, because 2 queries need to be run sequentially. This will disappear once the feature is stable and the legacy code gets cleaned up.

# Additional Changes

- Extend legacy key lifetimes so that tests can be run on an existing database even when the runs are more than 6 hours apart.
- Discovery endpoint returns all supported algorithms when the Web Key feature is enabled.

# Additional Context

- Closes https://github.com/zitadel/zitadel/issues/8031
- Part of https://github.com/zitadel/zitadel/issues/7809
- After https://github.com/zitadel/oidc/pull/637
- After https://github.com/zitadel/oidc/pull/638
516 lines
14 KiB
Go
516 lines
14 KiB
Go
package oidc
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"slices"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/go-jose/go-jose/v4"
|
|
"github.com/jonboulle/clockwork"
|
|
"github.com/muhlemmer/gu"
|
|
"github.com/zitadel/logging"
|
|
"github.com/zitadel/oidc/v3/pkg/op"
|
|
|
|
"github.com/zitadel/zitadel/internal/api/authz"
|
|
http_util "github.com/zitadel/zitadel/internal/api/http"
|
|
"github.com/zitadel/zitadel/internal/crypto"
|
|
"github.com/zitadel/zitadel/internal/eventstore"
|
|
"github.com/zitadel/zitadel/internal/query"
|
|
"github.com/zitadel/zitadel/internal/repository/instance"
|
|
"github.com/zitadel/zitadel/internal/repository/keypair"
|
|
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
|
"github.com/zitadel/zitadel/internal/zerrors"
|
|
)
|
|
|
|
var supportedWebKeyAlgs = []string{
|
|
string(jose.EdDSA),
|
|
string(jose.RS256),
|
|
string(jose.RS384),
|
|
string(jose.RS512),
|
|
string(jose.ES256),
|
|
string(jose.ES384),
|
|
string(jose.ES512),
|
|
}
|
|
|
|
func supportedSigningAlgs(ctx context.Context) []string {
|
|
if authz.GetFeatures(ctx).WebKey {
|
|
return supportedWebKeyAlgs
|
|
}
|
|
return []string{string(jose.RS256)}
|
|
}
|
|
|
|
type cachedPublicKey struct {
|
|
lastUse atomic.Int64 // unix micro time.
|
|
expiry *time.Time // expiry may be nil if the key does not expire.
|
|
webKey *jose.JSONWebKey
|
|
}
|
|
|
|
func newCachedPublicKey(key *jose.JSONWebKey, expiry *time.Time, now time.Time) *cachedPublicKey {
|
|
cachedKey := &cachedPublicKey{
|
|
expiry: expiry,
|
|
webKey: key,
|
|
}
|
|
cachedKey.setLastUse(now)
|
|
return cachedKey
|
|
}
|
|
|
|
func (c *cachedPublicKey) setLastUse(now time.Time) {
|
|
c.lastUse.Store(now.UnixMicro())
|
|
}
|
|
|
|
func (c *cachedPublicKey) getLastUse() time.Time {
|
|
return time.UnixMicro(c.lastUse.Load())
|
|
}
|
|
|
|
func (c *cachedPublicKey) expired(now time.Time, validity time.Duration) bool {
|
|
return c.getLastUse().Add(validity).Before(now)
|
|
}
|
|
|
|
// publicKeyCache caches public keys in a 2-dimensional map of Instance ID and Key ID.
|
|
// When a key is not present the queryKey function is called to obtain the key
|
|
// from the database.
|
|
type publicKeyCache struct {
|
|
mtx sync.RWMutex
|
|
instanceKeys map[string]map[string]*cachedPublicKey
|
|
|
|
// queryKey returns a public web key.
|
|
// If the key does not have expiry, Time may be nil.
|
|
queryKey func(ctx context.Context, keyID string) (*jose.JSONWebKey, *time.Time, error)
|
|
clock clockwork.Clock
|
|
}
|
|
|
|
// newPublicKeyCache initializes a keySetCache starts a purging Go routine.
|
|
// The purge routine deletes all public keys that are older than maxAge.
|
|
// When the passed context is done, the purge routine will terminate.
|
|
func newPublicKeyCache(background context.Context, maxAge time.Duration, queryKey func(ctx context.Context, keyID string) (*jose.JSONWebKey, *time.Time, error)) *publicKeyCache {
|
|
k := &publicKeyCache{
|
|
instanceKeys: make(map[string]map[string]*cachedPublicKey),
|
|
queryKey: queryKey,
|
|
clock: clockwork.FromContext(background), // defaults to real clock
|
|
}
|
|
go k.purgeOnInterval(background, k.clock.NewTicker(maxAge/5), maxAge)
|
|
return k
|
|
}
|
|
|
|
func (k *publicKeyCache) purgeOnInterval(background context.Context, ticker clockwork.Ticker, maxAge time.Duration) {
|
|
defer ticker.Stop()
|
|
for {
|
|
select {
|
|
case <-background.Done():
|
|
return
|
|
case <-ticker.Chan():
|
|
}
|
|
|
|
// do the actual purging
|
|
k.mtx.Lock()
|
|
for instanceID, keys := range k.instanceKeys {
|
|
for keyID, key := range keys {
|
|
if key.expired(k.clock.Now(), maxAge) {
|
|
delete(keys, keyID)
|
|
}
|
|
}
|
|
if len(keys) == 0 {
|
|
delete(k.instanceKeys, instanceID)
|
|
}
|
|
}
|
|
k.mtx.Unlock()
|
|
}
|
|
}
|
|
|
|
func (k *publicKeyCache) setKey(instanceID, keyID string, cachedKey *cachedPublicKey) {
|
|
k.mtx.Lock()
|
|
defer k.mtx.Unlock()
|
|
|
|
if keys, ok := k.instanceKeys[instanceID]; ok {
|
|
keys[keyID] = cachedKey
|
|
return
|
|
}
|
|
k.instanceKeys[instanceID] = map[string]*cachedPublicKey{keyID: cachedKey}
|
|
}
|
|
|
|
func (k *publicKeyCache) getKey(ctx context.Context, keyID string) (_ *cachedPublicKey, err error) {
|
|
ctx, span := tracing.NewSpan(ctx)
|
|
defer func() { span.EndWithError(err) }()
|
|
|
|
instanceID := authz.GetInstance(ctx).InstanceID()
|
|
|
|
k.mtx.RLock()
|
|
key, ok := k.instanceKeys[instanceID][keyID]
|
|
k.mtx.RUnlock()
|
|
|
|
if ok {
|
|
key.setLastUse(k.clock.Now())
|
|
} else {
|
|
newKey, expiry, err := k.queryKey(ctx, keyID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
key = newCachedPublicKey(newKey, expiry, k.clock.Now())
|
|
k.setKey(instanceID, keyID, key)
|
|
}
|
|
|
|
return key, nil
|
|
}
|
|
|
|
func (k *publicKeyCache) verifySignature(ctx context.Context, jws *jose.JSONWebSignature, checkKeyExpiry bool) (_ []byte, err error) {
|
|
ctx, span := tracing.NewSpan(ctx)
|
|
defer func() {
|
|
err = oidcError(err)
|
|
span.EndWithError(err)
|
|
}()
|
|
|
|
if len(jws.Signatures) != 1 {
|
|
return nil, zerrors.ThrowInvalidArgument(nil, "OIDC-Gid9s", "Errors.Token.Invalid")
|
|
}
|
|
key, err := k.getKey(ctx, jws.Signatures[0].Header.KeyID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if checkKeyExpiry && key.expiry != nil && key.expiry.Before(k.clock.Now()) {
|
|
return nil, zerrors.ThrowInvalidArgument(err, "QUERY-ciF4k", "Errors.Key.ExpireBeforeNow")
|
|
}
|
|
return jws.Verify(key.webKey)
|
|
}
|
|
|
|
type oidcKeySet struct {
|
|
*publicKeyCache
|
|
|
|
keyExpiryCheck bool
|
|
}
|
|
|
|
// newOidcKeySet returns an oidc.KeySet implementation around the passed cache.
|
|
// It is advised to reuse the same cache if different key set configurations are required.
|
|
func newOidcKeySet(cache *publicKeyCache, opts ...keySetOption) *oidcKeySet {
|
|
k := &oidcKeySet{
|
|
publicKeyCache: cache,
|
|
}
|
|
for _, opt := range opts {
|
|
opt(k)
|
|
}
|
|
return k
|
|
}
|
|
|
|
// VerifySignature implements the oidc.KeySet interface.
|
|
func (k *oidcKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (_ []byte, err error) {
|
|
return k.verifySignature(ctx, jws, k.keyExpiryCheck)
|
|
}
|
|
|
|
// keySetOption modifies an oidcKeySet; options are applied by newOidcKeySet.
type keySetOption func(*oidcKeySet)
|
|
|
|
// withKeyExpiryCheck forces VerifySignature to check the expiry of the public key.
|
|
// Note that public key expiry is not part of the standard,
|
|
// but is currently established behavior of zitadel.
|
|
// We might want to remove this check in the future.
|
|
func withKeyExpiryCheck(check bool) keySetOption {
|
|
return func(k *oidcKeySet) {
|
|
k.keyExpiryCheck = check
|
|
}
|
|
}
|
|
|
|
func jsonWebkey(key query.PublicKey) *jose.JSONWebKey {
|
|
return &jose.JSONWebKey{
|
|
KeyID: key.ID(),
|
|
Algorithm: key.Algorithm(),
|
|
Use: key.Use().String(),
|
|
Key: key.Key(),
|
|
}
|
|
}
|
|
|
|
// keySetMap maps key IDs to public key data.
// The data is parsed with crypto.BytesToPublicKey on lookup.
type keySetMap map[string][]byte
|
|
|
|
// getKey finds the keyID and parses the public key data
|
|
// into a JSONWebKey.
|
|
func (k keySetMap) getKey(keyID string) (*jose.JSONWebKey, error) {
|
|
pubKey, err := crypto.BytesToPublicKey(k[keyID])
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &jose.JSONWebKey{
|
|
Key: pubKey,
|
|
KeyID: keyID,
|
|
Use: crypto.KeyUsageSigning.String(),
|
|
}, nil
|
|
}
|
|
|
|
// VerifySignature implements the oidc.KeySet interface.
|
|
func (k keySetMap) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
|
|
if len(jws.Signatures) != 1 {
|
|
return nil, zerrors.ThrowInvalidArgument(nil, "OIDC-Eeth6", "Errors.Token.Invalid")
|
|
}
|
|
key, err := k.getKey(jws.Signatures[0].Header.KeyID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return jws.Verify(key)
|
|
}
|
|
|
|
const (
	locksTable = "projections.locks"
	signingKey = "signing_key"
	// oidcUser is the user ID set on the context by setOIDCCtx.
	oidcUser = "OIDC"

	// retryCount is the maximum number of attempts made by retry,
	// retryBackoff is the time retry sleeps after a failed attempt.
	retryCount   = 3
	retryBackoff = 500 * time.Millisecond
	// lockDuration is the duration passed to the locker when generating
	// a signing key pair: five times the total retry backoff.
	lockDuration = 5 * retryCount * retryBackoff
	// gracefulPeriod is added to the current time when querying
	// for the active private signing key.
	gracefulPeriod = 10 * time.Minute
)
|
|
|
|
// SigningKey wraps the query.PrivateKey to implement the op.SigningKey interface
|
|
type SigningKey struct {
|
|
algorithm jose.SignatureAlgorithm
|
|
id string
|
|
key interface{}
|
|
}
|
|
|
|
func (s *SigningKey) SignatureAlgorithm() jose.SignatureAlgorithm {
|
|
return s.algorithm
|
|
}
|
|
|
|
func (s *SigningKey) Key() interface{} {
|
|
return s.key
|
|
}
|
|
|
|
func (s *SigningKey) ID() string {
|
|
return s.id
|
|
}
|
|
|
|
// PublicKey wraps the query.PublicKey to implement the op.Key interface
|
|
type PublicKey struct {
|
|
key query.PublicKey
|
|
}
|
|
|
|
func (s *PublicKey) Algorithm() jose.SignatureAlgorithm {
|
|
return jose.SignatureAlgorithm(s.key.Algorithm())
|
|
}
|
|
|
|
func (s *PublicKey) Use() string {
|
|
return s.key.Use().String()
|
|
}
|
|
|
|
func (s *PublicKey) Key() interface{} {
|
|
return s.key.Key()
|
|
}
|
|
|
|
func (s *PublicKey) ID() string {
|
|
return s.key.ID()
|
|
}
|
|
|
|
// KeySet implements the op.Storage interface
|
|
func (o *OPStorage) KeySet(ctx context.Context) (keys []op.Key, err error) {
|
|
ctx, span := tracing.NewSpan(ctx)
|
|
defer func() { span.EndWithError(err) }()
|
|
err = retry(func() error {
|
|
publicKeys, err := o.query.ActivePublicKeys(ctx, time.Now())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
keys = make([]op.Key, len(publicKeys.Keys))
|
|
for i, key := range publicKeys.Keys {
|
|
keys[i] = &PublicKey{key}
|
|
}
|
|
return nil
|
|
})
|
|
return keys, err
|
|
}
|
|
|
|
// SignatureAlgorithms implements the op.Storage interface
|
|
func (o *OPStorage) SignatureAlgorithms(ctx context.Context) ([]jose.SignatureAlgorithm, error) {
|
|
key, err := o.SigningKey(ctx)
|
|
if err != nil {
|
|
logging.WithError(err).Warn("unable to fetch signing key")
|
|
return nil, err
|
|
}
|
|
return []jose.SignatureAlgorithm{key.SignatureAlgorithm()}, nil
|
|
}
|
|
|
|
// SigningKey implements the op.Storage interface
|
|
func (o *OPStorage) SigningKey(ctx context.Context) (key op.SigningKey, err error) {
|
|
err = retry(func() error {
|
|
key, err = o.getSigningKey(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if key == nil {
|
|
return zerrors.ThrowNotFound(nil, "OIDC-ve4Qu", "Errors.Internal")
|
|
}
|
|
return nil
|
|
})
|
|
return key, err
|
|
}
|
|
|
|
func (o *OPStorage) getSigningKey(ctx context.Context) (op.SigningKey, error) {
|
|
keys, err := o.query.ActivePrivateSigningKey(ctx, time.Now().Add(gracefulPeriod))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(keys.Keys) > 0 {
|
|
return o.privateKeyToSigningKey(selectSigningKey(keys.Keys))
|
|
}
|
|
var position float64
|
|
if keys.State != nil {
|
|
position = keys.State.Position
|
|
}
|
|
return nil, o.refreshSigningKey(ctx, o.signingKeyAlgorithm, position)
|
|
}
|
|
|
|
func (o *OPStorage) refreshSigningKey(ctx context.Context, algorithm string, position float64) error {
|
|
ok, err := o.ensureIsLatestKey(ctx, position)
|
|
if err != nil || !ok {
|
|
return zerrors.ThrowInternal(err, "OIDC-ASfh3", "cannot ensure that projection is up to date")
|
|
}
|
|
err = o.lockAndGenerateSigningKeyPair(ctx, algorithm)
|
|
if err != nil {
|
|
return zerrors.ThrowInternal(err, "OIDC-ADh31", "could not create signing key")
|
|
}
|
|
return zerrors.ThrowInternal(nil, "OIDC-Df1bh", "")
|
|
}
|
|
|
|
func (o *OPStorage) ensureIsLatestKey(ctx context.Context, position float64) (bool, error) {
|
|
maxSequence, err := o.getMaxKeySequence(ctx)
|
|
if err != nil {
|
|
return false, fmt.Errorf("error retrieving new events: %w", err)
|
|
}
|
|
return position >= maxSequence, nil
|
|
}
|
|
|
|
func (o *OPStorage) privateKeyToSigningKey(key query.PrivateKey) (_ op.SigningKey, err error) {
|
|
keyData, err := crypto.Decrypt(key.Key(), o.encAlg)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
privateKey, err := crypto.BytesToPrivateKey(keyData)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &SigningKey{
|
|
algorithm: jose.SignatureAlgorithm(key.Algorithm()),
|
|
key: privateKey,
|
|
id: key.ID(),
|
|
}, nil
|
|
}
|
|
|
|
func (o *OPStorage) lockAndGenerateSigningKeyPair(ctx context.Context, algorithm string) error {
|
|
logging.Info("lock and generate signing key pair")
|
|
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
|
|
errs := o.locker.Lock(ctx, lockDuration, authz.GetInstance(ctx).InstanceID())
|
|
err, ok := <-errs
|
|
if err != nil || !ok {
|
|
if zerrors.IsErrorAlreadyExists(err) {
|
|
return nil
|
|
}
|
|
logging.OnError(err).Debug("initial lock failed")
|
|
return err
|
|
}
|
|
|
|
return o.command.GenerateSigningKeyPair(setOIDCCtx(ctx), algorithm)
|
|
}
|
|
|
|
func (o *OPStorage) getMaxKeySequence(ctx context.Context) (float64, error) {
|
|
return o.eventstore.LatestSequence(ctx,
|
|
eventstore.NewSearchQueryBuilder(eventstore.ColumnsMaxSequence).
|
|
ResourceOwner(authz.GetInstance(ctx).InstanceID()).
|
|
AwaitOpenTransactions().
|
|
AllowTimeTravel().
|
|
AddQuery().
|
|
AggregateTypes(keypair.AggregateType).
|
|
EventTypes(
|
|
keypair.AddedEventType,
|
|
).
|
|
Or().
|
|
AggregateTypes(instance.AggregateType).
|
|
EventTypes(instance.InstanceRemovedEventType).
|
|
Builder(),
|
|
)
|
|
}
|
|
|
|
func selectSigningKey(keys []query.PrivateKey) query.PrivateKey {
|
|
return keys[len(keys)-1]
|
|
}
|
|
|
|
func setOIDCCtx(ctx context.Context) context.Context {
|
|
return authz.SetCtxData(ctx, authz.CtxData{UserID: oidcUser, OrgID: authz.GetInstance(ctx).InstanceID()})
|
|
}
|
|
|
|
func retry(retryable func() error) (err error) {
|
|
for i := 0; i < retryCount; i++ {
|
|
err = retryable()
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
time.Sleep(retryBackoff)
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (s *Server) Keys(ctx context.Context, r *op.Request[struct{}]) (_ *op.Response, err error) {
|
|
ctx, span := tracing.NewSpan(ctx)
|
|
defer func() { span.EndWithError(err) }()
|
|
|
|
if !authz.GetFeatures(ctx).WebKey {
|
|
return s.LegacyServer.Keys(ctx, r)
|
|
}
|
|
|
|
keyset, err := s.query.GetWebKeySet(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Return legacy keys, so we do not invalidate all tokens
|
|
// once the feature flag is enabled.
|
|
legacyKeys, err := s.query.ActivePublicKeys(ctx, time.Now())
|
|
logging.OnError(err).Error("oidc server: active public keys (legacy)")
|
|
appendPublicKeysToWebKeySet(keyset, legacyKeys)
|
|
|
|
resp := op.NewResponse(keyset)
|
|
if s.jwksCacheControlMaxAge != 0 {
|
|
resp.Header.Set(http_util.CacheControl,
|
|
fmt.Sprintf("max-age=%d, must-revalidate", int(s.jwksCacheControlMaxAge/time.Second)),
|
|
)
|
|
}
|
|
|
|
return resp, nil
|
|
}
|
|
|
|
func appendPublicKeysToWebKeySet(keyset *jose.JSONWebKeySet, pubkeys *query.PublicKeys) {
|
|
if pubkeys == nil || len(pubkeys.Keys) == 0 {
|
|
return
|
|
}
|
|
keyset.Keys = slices.Grow(keyset.Keys, len(pubkeys.Keys))
|
|
|
|
for _, key := range pubkeys.Keys {
|
|
keyset.Keys = append(keyset.Keys, jose.JSONWebKey{
|
|
Key: key.Key(),
|
|
KeyID: key.ID(),
|
|
Algorithm: key.Algorithm(),
|
|
Use: key.Use().String(),
|
|
})
|
|
}
|
|
}
|
|
|
|
func queryKeyFunc(q *query.Queries) func(ctx context.Context, keyID string) (*jose.JSONWebKey, *time.Time, error) {
|
|
return func(ctx context.Context, keyID string) (*jose.JSONWebKey, *time.Time, error) {
|
|
if authz.GetFeatures(ctx).WebKey {
|
|
webKey, err := q.GetPublicWebKeyByID(ctx, keyID)
|
|
if err == nil {
|
|
return webKey, nil, nil
|
|
}
|
|
if !zerrors.IsNotFound(err) {
|
|
return nil, nil, err
|
|
}
|
|
}
|
|
|
|
pubKey, err := q.GetPublicKeyByID(ctx, keyID)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
return jsonWebkey(pubKey), gu.Ptr(pubKey.Expiry()), nil
|
|
}
|
|
}
|