mirror of
https://github.com/zitadel/zitadel.git
synced 2024-12-14 20:08:02 +00:00
516 lines
14 KiB
Go
516 lines
14 KiB
Go
package oidc
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"slices"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/go-jose/go-jose/v4"
|
|
"github.com/jonboulle/clockwork"
|
|
"github.com/muhlemmer/gu"
|
|
"github.com/zitadel/logging"
|
|
"github.com/zitadel/oidc/v3/pkg/op"
|
|
|
|
"github.com/zitadel/zitadel/internal/api/authz"
|
|
http_util "github.com/zitadel/zitadel/internal/api/http"
|
|
"github.com/zitadel/zitadel/internal/crypto"
|
|
"github.com/zitadel/zitadel/internal/eventstore"
|
|
"github.com/zitadel/zitadel/internal/query"
|
|
"github.com/zitadel/zitadel/internal/repository/instance"
|
|
"github.com/zitadel/zitadel/internal/repository/keypair"
|
|
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
|
"github.com/zitadel/zitadel/internal/zerrors"
|
|
)
|
|
|
|
var supportedWebKeyAlgs = []string{
|
|
string(jose.EdDSA),
|
|
string(jose.RS256),
|
|
string(jose.RS384),
|
|
string(jose.RS512),
|
|
string(jose.ES256),
|
|
string(jose.ES384),
|
|
string(jose.ES512),
|
|
}
|
|
|
|
func supportedSigningAlgs(ctx context.Context) []string {
|
|
if authz.GetFeatures(ctx).WebKey {
|
|
return supportedWebKeyAlgs
|
|
}
|
|
return []string{string(jose.RS256)}
|
|
}
|
|
|
|
type cachedPublicKey struct {
|
|
lastUse atomic.Int64 // unix micro time.
|
|
expiry *time.Time // expiry may be nil if the key does not expire.
|
|
webKey *jose.JSONWebKey
|
|
}
|
|
|
|
func newCachedPublicKey(key *jose.JSONWebKey, expiry *time.Time, now time.Time) *cachedPublicKey {
|
|
cachedKey := &cachedPublicKey{
|
|
expiry: expiry,
|
|
webKey: key,
|
|
}
|
|
cachedKey.setLastUse(now)
|
|
return cachedKey
|
|
}
|
|
|
|
func (c *cachedPublicKey) setLastUse(now time.Time) {
|
|
c.lastUse.Store(now.UnixMicro())
|
|
}
|
|
|
|
func (c *cachedPublicKey) getLastUse() time.Time {
|
|
return time.UnixMicro(c.lastUse.Load())
|
|
}
|
|
|
|
func (c *cachedPublicKey) expired(now time.Time, validity time.Duration) bool {
|
|
return c.getLastUse().Add(validity).Before(now)
|
|
}
|
|
|
|
// publicKeyCache caches public keys in a 2-dimensional map of Instance ID and Key ID.
|
|
// When a key is not present the queryKey function is called to obtain the key
|
|
// from the database.
|
|
type publicKeyCache struct {
|
|
mtx sync.RWMutex
|
|
instanceKeys map[string]map[string]*cachedPublicKey
|
|
|
|
// queryKey returns a public web key.
|
|
// If the key does not have expiry, Time may be nil.
|
|
queryKey func(ctx context.Context, keyID string) (*jose.JSONWebKey, *time.Time, error)
|
|
clock clockwork.Clock
|
|
}
|
|
|
|
// newPublicKeyCache initializes a keySetCache starts a purging Go routine.
|
|
// The purge routine deletes all public keys that are older than maxAge.
|
|
// When the passed context is done, the purge routine will terminate.
|
|
func newPublicKeyCache(background context.Context, maxAge time.Duration, queryKey func(ctx context.Context, keyID string) (*jose.JSONWebKey, *time.Time, error)) *publicKeyCache {
|
|
k := &publicKeyCache{
|
|
instanceKeys: make(map[string]map[string]*cachedPublicKey),
|
|
queryKey: queryKey,
|
|
clock: clockwork.FromContext(background), // defaults to real clock
|
|
}
|
|
go k.purgeOnInterval(background, k.clock.NewTicker(maxAge/5), maxAge)
|
|
return k
|
|
}
|
|
|
|
func (k *publicKeyCache) purgeOnInterval(background context.Context, ticker clockwork.Ticker, maxAge time.Duration) {
|
|
defer ticker.Stop()
|
|
for {
|
|
select {
|
|
case <-background.Done():
|
|
return
|
|
case <-ticker.Chan():
|
|
}
|
|
|
|
// do the actual purging
|
|
k.mtx.Lock()
|
|
for instanceID, keys := range k.instanceKeys {
|
|
for keyID, key := range keys {
|
|
if key.expired(k.clock.Now(), maxAge) {
|
|
delete(keys, keyID)
|
|
}
|
|
}
|
|
if len(keys) == 0 {
|
|
delete(k.instanceKeys, instanceID)
|
|
}
|
|
}
|
|
k.mtx.Unlock()
|
|
}
|
|
}
|
|
|
|
func (k *publicKeyCache) setKey(instanceID, keyID string, cachedKey *cachedPublicKey) {
|
|
k.mtx.Lock()
|
|
defer k.mtx.Unlock()
|
|
|
|
if keys, ok := k.instanceKeys[instanceID]; ok {
|
|
keys[keyID] = cachedKey
|
|
return
|
|
}
|
|
k.instanceKeys[instanceID] = map[string]*cachedPublicKey{keyID: cachedKey}
|
|
}
|
|
|
|
func (k *publicKeyCache) getKey(ctx context.Context, keyID string) (_ *cachedPublicKey, err error) {
|
|
ctx, span := tracing.NewSpan(ctx)
|
|
defer func() { span.EndWithError(err) }()
|
|
|
|
instanceID := authz.GetInstance(ctx).InstanceID()
|
|
|
|
k.mtx.RLock()
|
|
key, ok := k.instanceKeys[instanceID][keyID]
|
|
k.mtx.RUnlock()
|
|
|
|
if ok {
|
|
key.setLastUse(k.clock.Now())
|
|
} else {
|
|
newKey, expiry, err := k.queryKey(ctx, keyID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
key = newCachedPublicKey(newKey, expiry, k.clock.Now())
|
|
k.setKey(instanceID, keyID, key)
|
|
}
|
|
|
|
return key, nil
|
|
}
|
|
|
|
func (k *publicKeyCache) verifySignature(ctx context.Context, jws *jose.JSONWebSignature, checkKeyExpiry bool) (_ []byte, err error) {
|
|
ctx, span := tracing.NewSpan(ctx)
|
|
defer func() {
|
|
err = oidcError(err)
|
|
span.EndWithError(err)
|
|
}()
|
|
|
|
if len(jws.Signatures) != 1 {
|
|
return nil, zerrors.ThrowInvalidArgument(nil, "OIDC-Gid9s", "Errors.Token.Invalid")
|
|
}
|
|
key, err := k.getKey(ctx, jws.Signatures[0].Header.KeyID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if checkKeyExpiry && key.expiry != nil && key.expiry.Before(k.clock.Now()) {
|
|
return nil, zerrors.ThrowInvalidArgument(err, "QUERY-ciF4k", "Errors.Key.ExpireBeforeNow")
|
|
}
|
|
return jws.Verify(key.webKey)
|
|
}
|
|
|
|
type oidcKeySet struct {
|
|
*publicKeyCache
|
|
|
|
keyExpiryCheck bool
|
|
}
|
|
|
|
// newOidcKeySet returns an oidc.KeySet implementation around the passed cache.
|
|
// It is advised to reuse the same cache if different key set configurations are required.
|
|
func newOidcKeySet(cache *publicKeyCache, opts ...keySetOption) *oidcKeySet {
|
|
k := &oidcKeySet{
|
|
publicKeyCache: cache,
|
|
}
|
|
for _, opt := range opts {
|
|
opt(k)
|
|
}
|
|
return k
|
|
}
|
|
|
|
// VerifySignature implements the oidc.KeySet interface.
|
|
func (k *oidcKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (_ []byte, err error) {
|
|
return k.verifySignature(ctx, jws, k.keyExpiryCheck)
|
|
}
|
|
|
|
// keySetOption modifies an oidcKeySet during construction by newOidcKeySet.
type keySetOption func(*oidcKeySet)
|
|
|
|
// withKeyExpiryCheck forces VerifySignature to check the expiry of the public key.
|
|
// Note that public key expiry is not part of the standard,
|
|
// but is currently established behavior of zitadel.
|
|
// We might want to remove this check in the future.
|
|
func withKeyExpiryCheck(check bool) keySetOption {
|
|
return func(k *oidcKeySet) {
|
|
k.keyExpiryCheck = check
|
|
}
|
|
}
|
|
|
|
func jsonWebkey(key query.PublicKey) *jose.JSONWebKey {
|
|
return &jose.JSONWebKey{
|
|
KeyID: key.ID(),
|
|
Algorithm: key.Algorithm(),
|
|
Use: key.Use().String(),
|
|
Key: key.Key(),
|
|
}
|
|
}
|
|
|
|
// keySetMap maps key IDs to raw public key data.
// The data is parsed into a public key on every lookup by getKey.
type keySetMap map[string][]byte
|
|
|
|
// getKey finds the keyID and parses the public key data
|
|
// into a JSONWebKey.
|
|
func (k keySetMap) getKey(keyID string) (*jose.JSONWebKey, error) {
|
|
pubKey, err := crypto.BytesToPublicKey(k[keyID])
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &jose.JSONWebKey{
|
|
Key: pubKey,
|
|
KeyID: keyID,
|
|
Use: crypto.KeyUsageSigning.String(),
|
|
}, nil
|
|
}
|
|
|
|
// VerifySignature implements the oidc.KeySet interface.
|
|
func (k keySetMap) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
|
|
if len(jws.Signatures) != 1 {
|
|
return nil, zerrors.ThrowInvalidArgument(nil, "OIDC-Eeth6", "Errors.Token.Invalid")
|
|
}
|
|
key, err := k.getKey(jws.Signatures[0].Header.KeyID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return jws.Verify(key)
|
|
}
|
|
|
|
const (
	// locksTable and signingKey are not referenced in this part of the file.
	locksTable = "projections.locks"
	signingKey = "signing_key"
	// oidcUser is the user ID set on the context by setOIDCCtx.
	oidcUser = "OIDC"

	// retryBackoff is the pause between attempts made by retry.
	retryBackoff = 500 * time.Millisecond
	// retryCount is the maximum number of attempts made by retry.
	retryCount = 3
	// lockDuration is the duration passed to the locker when generating
	// a signing key pair: five times the total retry backoff.
	lockDuration = retryCount * retryBackoff * 5
	// gracefulPeriod is added to the current time when querying
	// for the active private signing key.
	gracefulPeriod = 10 * time.Minute
)
|
|
|
|
// SigningKey wraps the query.PrivateKey to implement the op.SigningKey interface
|
|
type SigningKey struct {
|
|
algorithm jose.SignatureAlgorithm
|
|
id string
|
|
key interface{}
|
|
}
|
|
|
|
func (s *SigningKey) SignatureAlgorithm() jose.SignatureAlgorithm {
|
|
return s.algorithm
|
|
}
|
|
|
|
func (s *SigningKey) Key() interface{} {
|
|
return s.key
|
|
}
|
|
|
|
func (s *SigningKey) ID() string {
|
|
return s.id
|
|
}
|
|
|
|
// PublicKey wraps the query.PublicKey to implement the op.Key interface
|
|
type PublicKey struct {
|
|
key query.PublicKey
|
|
}
|
|
|
|
func (s *PublicKey) Algorithm() jose.SignatureAlgorithm {
|
|
return jose.SignatureAlgorithm(s.key.Algorithm())
|
|
}
|
|
|
|
func (s *PublicKey) Use() string {
|
|
return s.key.Use().String()
|
|
}
|
|
|
|
func (s *PublicKey) Key() interface{} {
|
|
return s.key.Key()
|
|
}
|
|
|
|
func (s *PublicKey) ID() string {
|
|
return s.key.ID()
|
|
}
|
|
|
|
// KeySet implements the op.Storage interface
|
|
func (o *OPStorage) KeySet(ctx context.Context) (keys []op.Key, err error) {
|
|
ctx, span := tracing.NewSpan(ctx)
|
|
defer func() { span.EndWithError(err) }()
|
|
err = retry(func() error {
|
|
publicKeys, err := o.query.ActivePublicKeys(ctx, time.Now())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
keys = make([]op.Key, len(publicKeys.Keys))
|
|
for i, key := range publicKeys.Keys {
|
|
keys[i] = &PublicKey{key}
|
|
}
|
|
return nil
|
|
})
|
|
return keys, err
|
|
}
|
|
|
|
// SignatureAlgorithms implements the op.Storage interface
|
|
func (o *OPStorage) SignatureAlgorithms(ctx context.Context) ([]jose.SignatureAlgorithm, error) {
|
|
key, err := o.SigningKey(ctx)
|
|
if err != nil {
|
|
logging.WithError(err).Warn("unable to fetch signing key")
|
|
return nil, err
|
|
}
|
|
return []jose.SignatureAlgorithm{key.SignatureAlgorithm()}, nil
|
|
}
|
|
|
|
// SigningKey implements the op.Storage interface
|
|
func (o *OPStorage) SigningKey(ctx context.Context) (key op.SigningKey, err error) {
|
|
err = retry(func() error {
|
|
key, err = o.getSigningKey(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if key == nil {
|
|
return zerrors.ThrowNotFound(nil, "OIDC-ve4Qu", "Errors.Internal")
|
|
}
|
|
return nil
|
|
})
|
|
return key, err
|
|
}
|
|
|
|
func (o *OPStorage) getSigningKey(ctx context.Context) (op.SigningKey, error) {
|
|
keys, err := o.query.ActivePrivateSigningKey(ctx, time.Now().Add(gracefulPeriod))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(keys.Keys) > 0 {
|
|
return o.privateKeyToSigningKey(selectSigningKey(keys.Keys))
|
|
}
|
|
var position float64
|
|
if keys.State != nil {
|
|
position = keys.State.Position
|
|
}
|
|
return nil, o.refreshSigningKey(ctx, o.signingKeyAlgorithm, position)
|
|
}
|
|
|
|
func (o *OPStorage) refreshSigningKey(ctx context.Context, algorithm string, position float64) error {
|
|
ok, err := o.ensureIsLatestKey(ctx, position)
|
|
if err != nil || !ok {
|
|
return zerrors.ThrowInternal(err, "OIDC-ASfh3", "cannot ensure that projection is up to date")
|
|
}
|
|
err = o.lockAndGenerateSigningKeyPair(ctx, algorithm)
|
|
if err != nil {
|
|
return zerrors.ThrowInternal(err, "OIDC-ADh31", "could not create signing key")
|
|
}
|
|
return zerrors.ThrowInternal(nil, "OIDC-Df1bh", "")
|
|
}
|
|
|
|
func (o *OPStorage) ensureIsLatestKey(ctx context.Context, position float64) (bool, error) {
|
|
maxSequence, err := o.getMaxKeySequence(ctx)
|
|
if err != nil {
|
|
return false, fmt.Errorf("error retrieving new events: %w", err)
|
|
}
|
|
return position >= maxSequence, nil
|
|
}
|
|
|
|
func (o *OPStorage) privateKeyToSigningKey(key query.PrivateKey) (_ op.SigningKey, err error) {
|
|
keyData, err := crypto.Decrypt(key.Key(), o.encAlg)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
privateKey, err := crypto.BytesToPrivateKey(keyData)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &SigningKey{
|
|
algorithm: jose.SignatureAlgorithm(key.Algorithm()),
|
|
key: privateKey,
|
|
id: key.ID(),
|
|
}, nil
|
|
}
|
|
|
|
func (o *OPStorage) lockAndGenerateSigningKeyPair(ctx context.Context, algorithm string) error {
|
|
logging.Info("lock and generate signing key pair")
|
|
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
|
|
errs := o.locker.Lock(ctx, lockDuration, authz.GetInstance(ctx).InstanceID())
|
|
err, ok := <-errs
|
|
if err != nil || !ok {
|
|
if zerrors.IsErrorAlreadyExists(err) {
|
|
return nil
|
|
}
|
|
logging.OnError(err).Debug("initial lock failed")
|
|
return err
|
|
}
|
|
|
|
return o.command.GenerateSigningKeyPair(setOIDCCtx(ctx), algorithm)
|
|
}
|
|
|
|
func (o *OPStorage) getMaxKeySequence(ctx context.Context) (float64, error) {
|
|
return o.eventstore.LatestSequence(ctx,
|
|
eventstore.NewSearchQueryBuilder(eventstore.ColumnsMaxSequence).
|
|
ResourceOwner(authz.GetInstance(ctx).InstanceID()).
|
|
AwaitOpenTransactions().
|
|
AllowTimeTravel().
|
|
AddQuery().
|
|
AggregateTypes(keypair.AggregateType).
|
|
EventTypes(
|
|
keypair.AddedEventType,
|
|
).
|
|
Or().
|
|
AggregateTypes(instance.AggregateType).
|
|
EventTypes(instance.InstanceRemovedEventType).
|
|
Builder(),
|
|
)
|
|
}
|
|
|
|
func selectSigningKey(keys []query.PrivateKey) query.PrivateKey {
|
|
return keys[len(keys)-1]
|
|
}
|
|
|
|
func setOIDCCtx(ctx context.Context) context.Context {
|
|
return authz.SetCtxData(ctx, authz.CtxData{UserID: oidcUser, OrgID: authz.GetInstance(ctx).InstanceID()})
|
|
}
|
|
|
|
func retry(retryable func() error) (err error) {
|
|
for i := 0; i < retryCount; i++ {
|
|
err = retryable()
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
time.Sleep(retryBackoff)
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (s *Server) Keys(ctx context.Context, r *op.Request[struct{}]) (_ *op.Response, err error) {
|
|
ctx, span := tracing.NewSpan(ctx)
|
|
defer func() { span.EndWithError(err) }()
|
|
|
|
if !authz.GetFeatures(ctx).WebKey {
|
|
return s.LegacyServer.Keys(ctx, r)
|
|
}
|
|
|
|
keyset, err := s.query.GetWebKeySet(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Return legacy keys, so we do not invalidate all tokens
|
|
// once the feature flag is enabled.
|
|
legacyKeys, err := s.query.ActivePublicKeys(ctx, time.Now())
|
|
logging.OnError(err).Error("oidc server: active public keys (legacy)")
|
|
appendPublicKeysToWebKeySet(keyset, legacyKeys)
|
|
|
|
resp := op.NewResponse(keyset)
|
|
if s.jwksCacheControlMaxAge != 0 {
|
|
resp.Header.Set(http_util.CacheControl,
|
|
fmt.Sprintf("max-age=%d, must-revalidate", int(s.jwksCacheControlMaxAge/time.Second)),
|
|
)
|
|
}
|
|
|
|
return resp, nil
|
|
}
|
|
|
|
func appendPublicKeysToWebKeySet(keyset *jose.JSONWebKeySet, pubkeys *query.PublicKeys) {
|
|
if pubkeys == nil || len(pubkeys.Keys) == 0 {
|
|
return
|
|
}
|
|
keyset.Keys = slices.Grow(keyset.Keys, len(pubkeys.Keys))
|
|
|
|
for _, key := range pubkeys.Keys {
|
|
keyset.Keys = append(keyset.Keys, jose.JSONWebKey{
|
|
Key: key.Key(),
|
|
KeyID: key.ID(),
|
|
Algorithm: key.Algorithm(),
|
|
Use: key.Use().String(),
|
|
})
|
|
}
|
|
}
|
|
|
|
func queryKeyFunc(q *query.Queries) func(ctx context.Context, keyID string) (*jose.JSONWebKey, *time.Time, error) {
|
|
return func(ctx context.Context, keyID string) (*jose.JSONWebKey, *time.Time, error) {
|
|
if authz.GetFeatures(ctx).WebKey {
|
|
webKey, err := q.GetPublicWebKeyByID(ctx, keyID)
|
|
if err == nil {
|
|
return webKey, nil, nil
|
|
}
|
|
if !zerrors.IsNotFound(err) {
|
|
return nil, nil, err
|
|
}
|
|
}
|
|
|
|
pubKey, err := q.GetPublicKeyByID(ctx, keyID)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
return jsonWebkey(pubKey), gu.Ptr(pubKey.Expiry()), nil
|
|
}
|
|
}
|