fix(oidc): ignore public key expiry for ID Token hints (#7293)

* fix(oidc): ignore public key expiry for ID Token hints

This splits the key sets used for access token and ID token hints.
ID Token hints should be able to be verified by with public keys that are already expired.
However, we do not want to change this behavior for Access Tokens,
where an error for an expired public key is still returned.

The public key cache is modified to purge public keys based on last use,
instead of expiry.
The cache is shared between both verifiers.

* resolve review comments

* pin oidc 3.11
This commit is contained in:
Tim Möhlmann
2024-01-29 17:11:52 +02:00
committed by GitHub
parent 5e23ea55b2
commit df57a64ed7
12 changed files with 201 additions and 147 deletions

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/go-jose/go-jose/v3"
@@ -22,31 +23,55 @@ import (
"github.com/zitadel/zitadel/internal/zerrors"
)
// keySetCache implements oidc.KeySet for Access Token verification.
// Public Keys are cached in a 2-dimensional map of Instance ID and Key ID.
type cachedPublicKey struct {
lastUse atomic.Int64 // unix micro time.
query.PublicKey
}
func newCachedPublicKey(key query.PublicKey, now time.Time) *cachedPublicKey {
cachedKey := &cachedPublicKey{
PublicKey: key,
}
cachedKey.setLastUse(now)
return cachedKey
}
func (c *cachedPublicKey) setLastUse(now time.Time) {
c.lastUse.Store(now.UnixMicro())
}
func (c *cachedPublicKey) getLastUse() time.Time {
return time.UnixMicro(c.lastUse.Load())
}
func (c *cachedPublicKey) expired(now time.Time, validity time.Duration) bool {
return c.getLastUse().Add(validity).Before(now)
}
// publicKeyCache caches public keys in a 2-dimensional map of Instance ID and Key ID.
// When a key is not present the queryKey function is called to obtain the key
// from the database.
type keySetCache struct {
type publicKeyCache struct {
mtx sync.RWMutex
instanceKeys map[string]map[string]query.PublicKey
queryKey func(ctx context.Context, keyID string, current time.Time) (query.PublicKey, error)
instanceKeys map[string]map[string]*cachedPublicKey
queryKey func(ctx context.Context, keyID string) (query.PublicKey, error)
clock clockwork.Clock
}
// newKeySet initializes a keySetCache and starts a purging Go routine,
// which runs once every purgeInterval.
// newPublicKeyCache initializes a keySetCache starts a purging Go routine.
// The purge routine deletes all public keys that are older than maxAge.
// When the passed context is done, the purge routine will terminate.
func newKeySet(background context.Context, purgeInterval time.Duration, queryKey func(ctx context.Context, keyID string, current time.Time) (query.PublicKey, error)) *keySetCache {
k := &keySetCache{
instanceKeys: make(map[string]map[string]query.PublicKey),
func newPublicKeyCache(background context.Context, maxAge time.Duration, queryKey func(ctx context.Context, keyID string) (query.PublicKey, error)) *publicKeyCache {
k := &publicKeyCache{
instanceKeys: make(map[string]map[string]*cachedPublicKey),
queryKey: queryKey,
clock: clockwork.FromContext(background), // defaults to real clock
}
go k.purgeOnInterval(background, k.clock.NewTicker(purgeInterval))
go k.purgeOnInterval(background, k.clock.NewTicker(maxAge/5), maxAge)
return k
}
func (k *keySetCache) purgeOnInterval(background context.Context, ticker clockwork.Ticker) {
func (k *publicKeyCache) purgeOnInterval(background context.Context, ticker clockwork.Ticker, maxAge time.Duration) {
defer ticker.Stop()
for {
select {
@@ -59,7 +84,7 @@ func (k *keySetCache) purgeOnInterval(background context.Context, ticker clockwo
k.mtx.Lock()
for instanceID, keys := range k.instanceKeys {
for keyID, key := range keys {
if key.Expiry().Before(k.clock.Now()) {
if key.expired(k.clock.Now(), maxAge) {
delete(keys, keyID)
}
}
@@ -71,19 +96,18 @@ func (k *keySetCache) purgeOnInterval(background context.Context, ticker clockwo
}
}
func (k *keySetCache) setKey(instanceID, keyID string, key query.PublicKey) {
func (k *publicKeyCache) setKey(instanceID, keyID string, cachedKey *cachedPublicKey) {
k.mtx.Lock()
defer k.mtx.Unlock()
if keys, ok := k.instanceKeys[instanceID]; ok {
keys[keyID] = key
keys[keyID] = cachedKey
return
}
k.instanceKeys[instanceID] = map[string]query.PublicKey{keyID: key}
k.instanceKeys[instanceID] = map[string]*cachedPublicKey{keyID: cachedKey}
}
func (k *keySetCache) getKey(ctx context.Context, keyID string) (_ *jose.JSONWebKey, err error) {
func (k *publicKeyCache) getKey(ctx context.Context, keyID string) (_ *cachedPublicKey, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
@@ -94,22 +118,20 @@ func (k *keySetCache) getKey(ctx context.Context, keyID string) (_ *jose.JSONWeb
k.mtx.RUnlock()
if ok {
if key.Expiry().After(k.clock.Now()) {
return jsonWebkey(key), nil
key.setLastUse(k.clock.Now())
} else {
newKey, err := k.queryKey(ctx, keyID)
if err != nil {
return nil, err
}
return nil, zerrors.ThrowInvalidArgument(nil, "OIDC-Zoh9E", "Errors.Key.ExpireBeforeNow")
key = newCachedPublicKey(newKey, k.clock.Now())
k.setKey(instanceID, keyID, key)
}
key, err = k.queryKey(ctx, keyID, k.clock.Now())
if err != nil {
return nil, err
}
k.setKey(instanceID, keyID, key)
return jsonWebkey(key), nil
return key, nil
}
// VerifySignature implements the oidc.KeySet interface.
func (k *keySetCache) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (_ []byte, err error) {
func (k *publicKeyCache) verifySignature(ctx context.Context, jws *jose.JSONWebSignature, checkKeyExpiry bool) (_ []byte, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() {
err = oidcError(err)
@@ -123,7 +145,45 @@ func (k *keySetCache) VerifySignature(ctx context.Context, jws *jose.JSONWebSign
if err != nil {
return nil, err
}
return jws.Verify(key)
if checkKeyExpiry && key.Expiry().Before(k.clock.Now()) {
return nil, zerrors.ThrowInvalidArgument(err, "QUERY-ciF4k", "Errors.Key.ExpireBeforeNow")
}
return jws.Verify(jsonWebkey(key))
}
type oidcKeySet struct {
*publicKeyCache
keyExpiryCheck bool
}
// newOidcKeySet returns an oidc.KeySet implementation around the passed cache.
// It is advised to reuse the same cache if different key set configurations are required.
func newOidcKeySet(cache *publicKeyCache, opts ...keySetOption) *oidcKeySet {
k := &oidcKeySet{
publicKeyCache: cache,
}
for _, opt := range opts {
opt(k)
}
return k
}
// VerifySignature implements the oidc.KeySet interface.
func (k *oidcKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (_ []byte, err error) {
return k.verifySignature(ctx, jws, k.keyExpiryCheck)
}
type keySetOption func(*oidcKeySet)
// withKeyExpiryCheck forces VerifySignature to check the expiry of the public key.
// Note that public key expiry is not part of the standard,
// but is currently established behavior of zitadel.
// We might want to remove this check in the future.
func withKeyExpiryCheck(check bool) keySetOption {
return func(k *oidcKeySet) {
k.keyExpiryCheck = check
}
}
func jsonWebkey(key query.PublicKey) *jose.JSONWebKey {