Tim Möhlmann 016676e1dc
chore(oidc): graduate webkey to stable (#10122)
# Which Problems Are Solved

Stabilize the usage of webkeys.

# How the Problems Are Solved

- Remove all legacy signing key code from the OIDC API
- Remove the webkey feature flag from proto
- Remove the webkey feature flag from console
- Clean up documentation

# Additional Changes

- Resolved some canonical header linter errors in OIDC
- Use the constant for `projections.lock` in the saml package.

# Additional Context

- Closes #10029
- After #10105
- After #10061
2025-06-26 19:17:45 +03:00

321 lines
8.7 KiB
Go

package oidc
import (
"context"
"fmt"
"slices"
"sync"
"sync/atomic"
"time"
"github.com/go-jose/go-jose/v4"
"github.com/jonboulle/clockwork"
"github.com/zitadel/oidc/v3/pkg/op"
"github.com/zitadel/zitadel/internal/api/authz"
http_util "github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
"github.com/zitadel/zitadel/internal/zerrors"
)
// supportedWebKeyAlgs lists the names of the JOSE signature algorithms
// that supportedSigningAlgs reports: EdDSA and the RS* and ES* families,
// each in their 256, 384 and 512 variants.
var supportedWebKeyAlgs = []string{
	string(jose.EdDSA),
	string(jose.RS256),
	string(jose.RS384),
	string(jose.RS512),
	string(jose.ES256),
	string(jose.ES384),
	string(jose.ES512),
}
// supportedSigningAlgs returns the signature algorithm names from supportedWebKeyAlgs.
// The package-level slice itself is returned, not a copy,
// so a modification of the result modifies the shared list.
func supportedSigningAlgs() []string {
	return supportedWebKeyAlgs
}
// cachedPublicKey is a single entry of the publicKeyCache.
// Next to the key itself, it remembers when the entry was last used.
// The last use is stored atomically, so it can be updated
// without holding the lock of the cache.
type cachedPublicKey struct {
	lastUse atomic.Int64 // unix micro time.
	expiry  *time.Time   // expiry may be nil if the key does not expire.
	webKey  *jose.JSONWebKey
}
// newCachedPublicKey builds a cache entry for key with its optional expiry.
// The entry is marked as last used at now.
func newCachedPublicKey(key *jose.JSONWebKey, expiry *time.Time, now time.Time) *cachedPublicKey {
	entry := new(cachedPublicKey)
	entry.webKey = key
	entry.expiry = expiry
	// same as entry.setLastUse(now)
	entry.lastUse.Store(now.UnixMicro())
	return entry
}
// setLastUse atomically records now, in microseconds since the Unix epoch,
// as the moment the key was last used.
func (c *cachedPublicKey) setLastUse(now time.Time) {
	micros := now.UnixMicro()
	c.lastUse.Store(micros)
}
// getLastUse atomically loads the stored Unix microsecond timestamp
// and returns it as a time.Time.
func (c *cachedPublicKey) getLastUse() time.Time {
	micros := c.lastUse.Load()
	return time.UnixMicro(micros)
}
// expired reports whether the key was last used more than validity before now.
// It only looks at the last use; the expiry field of the key plays no role here.
func (c *cachedPublicKey) expired(now time.Time, validity time.Duration) bool {
	keepUntil := c.getLastUse().Add(validity)
	return keepUntil.Before(now)
}
// publicKeyCache caches public keys in a 2-dimensional map of Instance ID and Key ID.
// When a key is not present the queryKey function is called to obtain the key
// from the database.
type publicKeyCache struct {
	// mtx guards instanceKeys.
	mtx sync.RWMutex

	// instanceKeys maps Instance ID to Key ID to the cached key.
	instanceKeys map[string]map[string]*cachedPublicKey

	// queryKey returns a public web key.
	// If the key does not have expiry, Time may be nil.
	queryKey func(ctx context.Context, keyID string) (*jose.JSONWebKey, *time.Time, error)

	// clock is the time source for last-use timestamps, expiry checks
	// and the purge ticker.
	clock clockwork.Clock
}
// newPublicKeyCache initializes a publicKeyCache and starts a purging goroutine.
// The purge routine runs at an interval of a fifth of maxAge and deletes
// all public keys that were last used longer than maxAge ago.
// When the passed context is done, the purge routine will terminate.
// The clock is obtained from the passed context.
func newPublicKeyCache(background context.Context, maxAge time.Duration, queryKey func(ctx context.Context, keyID string) (*jose.JSONWebKey, *time.Time, error)) *publicKeyCache {
	clock := clockwork.FromContext(background) // defaults to real clock
	cache := &publicKeyCache{
		instanceKeys: make(map[string]map[string]*cachedPublicKey),
		queryKey:     queryKey,
		clock:        clock,
	}

	purgeTicker := clock.NewTicker(maxAge / 5)
	go cache.purgeOnInterval(background, purgeTicker, maxAge)

	return cache
}
// purgeOnInterval blocks until the background context is done.
// On each tick of ticker it removes every cached key that was last used
// more than maxAge ago, and drops instances that have no keys left.
// The ticker is stopped when the function returns.
func (k *publicKeyCache) purgeOnInterval(background context.Context, ticker clockwork.Ticker, maxAge time.Duration) {
	defer ticker.Stop()
	for {
		select {
		case <-background.Done():
			return
		case <-ticker.Chan():
		}

		// Read the clock once per purge run and before taking the lock:
		// all keys are compared against the same instant,
		// and the write lock is not held while obtaining the time.
		now := k.clock.Now()

		// do the actual purging
		k.mtx.Lock()
		for instanceID, keys := range k.instanceKeys {
			for keyID, key := range keys {
				if key.expired(now, maxAge) {
					delete(keys, keyID)
				}
			}
			// don't keep empty maps around for instances without keys.
			if len(keys) == 0 {
				delete(k.instanceKeys, instanceID)
			}
		}
		k.mtx.Unlock()
	}
}
// setKey stores cachedKey under the instanceID and keyID, holding the write lock.
// An existing entry for the same IDs is overwritten.
// The key map of the instance is created when it does not exist yet.
func (k *publicKeyCache) setKey(instanceID, keyID string, cachedKey *cachedPublicKey) {
	k.mtx.Lock()
	defer k.mtx.Unlock()

	instanceKeys, found := k.instanceKeys[instanceID]
	if !found {
		instanceKeys = make(map[string]*cachedPublicKey, 1)
		k.instanceKeys[instanceID] = instanceKeys
	}
	instanceKeys[keyID] = cachedKey
}
// getKey returns the public key with keyID for the instance found in ctx.
// A cached key has its last use updated and is returned directly.
// Otherwise the key is obtained through queryKey and added to the cache.
// Errors from queryKey are returned as-is and nothing is cached in that case.
func (k *publicKeyCache) getKey(ctx context.Context, keyID string) (_ *cachedPublicKey, err error) {
	ctx, span := tracing.NewSpan(ctx)
	defer func() { span.EndWithError(err) }()

	instanceID := authz.GetInstance(ctx).InstanceID()

	k.mtx.RLock()
	cached, hit := k.instanceKeys[instanceID][keyID]
	k.mtx.RUnlock()
	if hit {
		// lastUse is atomic, so no lock is needed for the update.
		cached.setLastUse(k.clock.Now())
		return cached, nil
	}

	webKey, expiry, err := k.queryKey(ctx, keyID)
	if err != nil {
		return nil, err
	}
	fetched := newCachedPublicKey(webKey, expiry, k.clock.Now())
	k.setKey(instanceID, keyID, fetched)
	return fetched, nil
}
// verifySignature verifies jws against the public key named in its signature header
// and returns the result of jws.Verify.
//
// The jws must carry exactly one signature, otherwise an invalid argument error is returned.
// The key is looked up by key ID through getKey.
// When checkKeyExpiry is true and the key has an expiry in the past,
// verification is refused with an invalid argument error.
//
// Every returned error is passed through oidcError before it is recorded on the span.
func (k *publicKeyCache) verifySignature(ctx context.Context, jws *jose.JSONWebSignature, checkKeyExpiry bool) (_ []byte, err error) {
	ctx, span := tracing.NewSpan(ctx)
	defer func() {
		err = oidcError(err)
		span.EndWithError(err)
	}()

	if len(jws.Signatures) != 1 {
		return nil, zerrors.ThrowInvalidArgument(nil, "OIDC-Gid9s", "Errors.Token.Invalid")
	}
	key, err := k.getKey(ctx, jws.Signatures[0].Header.KeyID)
	if err != nil {
		return nil, err
	}
	if checkKeyExpiry && key.expiry != nil && key.expiry.Before(k.clock.Now()) {
		// err is always nil at this point, there is no parent error to wrap.
		return nil, zerrors.ThrowInvalidArgument(nil, "QUERY-ciF4k", "Errors.Key.ExpireBeforeNow")
	}
	return jws.Verify(key.webKey)
}
// oidcKeySet combines a publicKeyCache with the configuration
// used when verifying signatures through VerifySignature.
type oidcKeySet struct {
	*publicKeyCache

	// keyExpiryCheck is passed to verifySignature as checkKeyExpiry.
	keyExpiryCheck bool
}
// newOidcKeySet returns an oidc.KeySet implementation around the passed cache.
// The options are applied in the order they are passed.
// It is advised to reuse the same cache if different key set configurations are required.
func newOidcKeySet(cache *publicKeyCache, opts ...keySetOption) *oidcKeySet {
	keySet := new(oidcKeySet)
	keySet.publicKeyCache = cache
	for i := range opts {
		opts[i](keySet)
	}
	return keySet
}
// VerifySignature implements the oidc.KeySet interface.
// It delegates to the embedded cache,
// applying the key expiry check configured on the key set.
func (k *oidcKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (_ []byte, err error) {
	checkExpiry := k.keyExpiryCheck
	return k.publicKeyCache.verifySignature(ctx, jws, checkExpiry)
}
// keySetOption modifies an oidcKeySet during construction in newOidcKeySet.
type keySetOption func(*oidcKeySet)
// withKeyExpiryCheck sets whether VerifySignature checks the expiry of the public key.
// Note that public key expiry is not part of the standard,
// but is currently established behavior of zitadel.
// We might want to remove this check in the future.
func withKeyExpiryCheck(check bool) keySetOption {
	var apply keySetOption = func(keySet *oidcKeySet) {
		keySet.keyExpiryCheck = check
	}
	return apply
}
// keySetMap is a mapping of key IDs to public key data.
// The data is parsed into a public key on every lookup by getKey.
type keySetMap map[string][]byte
// getKey looks up the public key data for keyID and parses it
// into a JSONWebKey with signing usage.
// For a keyID that is not in the map, nil data is handed to the parser;
// a parse error is returned as-is.
func (k keySetMap) getKey(keyID string) (*jose.JSONWebKey, error) {
	data := k[keyID]
	publicKey, err := crypto.BytesToPublicKey(data)
	if err != nil {
		return nil, err
	}

	webKey := new(jose.JSONWebKey)
	webKey.Key = publicKey
	webKey.KeyID = keyID
	webKey.Use = crypto.KeyUsageSigning.String()
	return webKey, nil
}
// VerifySignature implements the oidc.KeySet interface.
// The jws must carry exactly one signature, whose key ID selects
// the public key from the map used for verification.
func (k keySetMap) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
	if count := len(jws.Signatures); count != 1 {
		return nil, zerrors.ThrowInvalidArgument(nil, "OIDC-Eeth6", "Errors.Token.Invalid")
	}

	keyID := jws.Signatures[0].Header.KeyID
	publicKey, err := k.getKey(keyID)
	if err != nil {
		return nil, err
	}
	return jws.Verify(publicKey)
}
// Identifiers. Their consumers are not part of this section of the file.
const (
	signingKey = "signing_key"
	oidcUser   = "OIDC"
)

// Retry, lock and grace timing.
const (
	retryCount     = 3
	retryBackoff   = 500 * time.Millisecond
	lockDuration   = 5 * retryCount * retryBackoff // 7.5s: five times the total backoff of all retries.
	gracefulPeriod = 10 * time.Minute
)
// SigningKey holds key material together with its ID and signature algorithm,
// in order to implement the op.SigningKey interface.
type SigningKey struct {
	algorithm jose.SignatureAlgorithm
	id        string
	key       any
}
// SignatureAlgorithm returns the signature algorithm stored in the SigningKey.
func (s *SigningKey) SignatureAlgorithm() jose.SignatureAlgorithm {
	return s.algorithm
}
// Key returns the key material stored in the SigningKey, without copying it.
func (s *SigningKey) Key() any {
	return s.key
}
// ID returns the key ID stored in the SigningKey.
func (s *SigningKey) ID() string {
	return s.id
}
// KeySet implements the op.Storage interface
//
// It never returns: every call panics with the value returned by o.panicErr.
// NOTE(review): presumably the key set is served by Server.Keys instead,
// so this method is not expected to be reached — confirm.
func (o *OPStorage) KeySet(ctx context.Context) (keys []op.Key, err error) {
	panic(o.panicErr("KeySet"))
}
// SignatureAlgorithms implements the op.Storage interface
//
// It never returns: every call panics with the value returned by o.panicErr.
// NOTE(review): presumably this method only exists to satisfy the interface
// and is not expected to be reached — confirm.
func (o *OPStorage) SignatureAlgorithms(ctx context.Context) ([]jose.SignatureAlgorithm, error) {
	panic(o.panicErr("SignatureAlgorithms"))
}
// SigningKey implements the op.Storage interface
//
// It never returns: every call panics with the value returned by o.panicErr.
// NOTE(review): presumably this method only exists to satisfy the interface
// and is not expected to be reached — confirm.
func (o *OPStorage) SigningKey(ctx context.Context) (key op.SigningKey, err error) {
	panic(o.panicErr("SigningKey"))
}
// Keys responds with the web key set obtained from the query layer.
// When jwksCacheControlMaxAge is not zero, the header named by http_util.CacheControl
// is set on the response, with the max age in whole seconds and must-revalidate.
func (s *Server) Keys(ctx context.Context, r *op.Request[struct{}]) (_ *op.Response, err error) {
	ctx, span := tracing.NewSpan(ctx)
	defer func() { span.EndWithError(err) }()

	keyset, err := s.query.GetWebKeySet(ctx)
	if err != nil {
		return nil, err
	}

	resp := op.NewResponse(keyset)
	if s.jwksCacheControlMaxAge == 0 {
		// no caching instructions configured.
		return resp, nil
	}

	maxAgeSeconds := int(s.jwksCacheControlMaxAge / time.Second)
	cacheControl := fmt.Sprintf("max-age=%d, must-revalidate", maxAgeSeconds)
	resp.Header.Set(http_util.CacheControl, cacheControl)
	return resp, nil
}
// appendPublicKeysToWebKeySet converts each key in pubkeys into a JSONWebKey
// and appends it to keyset.Keys.
// A nil or empty pubkeys leaves keyset untouched.
func appendPublicKeysToWebKeySet(keyset *jose.JSONWebKeySet, pubkeys *query.PublicKeys) {
	if pubkeys == nil {
		return
	}
	count := len(pubkeys.Keys)
	if count == 0 {
		return
	}

	// make room for all keys at once, instead of growing on each append.
	keyset.Keys = slices.Grow(keyset.Keys, count)
	for _, publicKey := range pubkeys.Keys {
		webKey := jose.JSONWebKey{
			Key:       publicKey.Key(),
			KeyID:     publicKey.ID(),
			Algorithm: publicKey.Algorithm(),
			Use:       publicKey.Use().String(),
		}
		keyset.Keys = append(keyset.Keys, webKey)
	}
}
// queryKeyFunc returns a function with the signature of publicKeyCache.queryKey,
// which obtains the public web key by its ID through q.
// The expiry returned by that function is always nil.
func queryKeyFunc(q *query.Queries) func(ctx context.Context, keyID string) (*jose.JSONWebKey, *time.Time, error) {
	lookup := func(ctx context.Context, keyID string) (*jose.JSONWebKey, *time.Time, error) {
		publicKey, err := q.GetPublicWebKeyByID(ctx, keyID)
		if err != nil {
			return nil, nil, err
		}
		return publicKey, nil, nil
	}
	return lookup
}