mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-12 01:37:31 +00:00
fix: improve key rotation (#1328)
* fix: improve key rotation * update oidc pkg version
This commit is contained in:
@@ -2,27 +2,58 @@ package eventstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/caos/logging"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/caos/zitadel/internal/auth/repository/eventsourcing/view"
|
||||
"github.com/caos/zitadel/internal/crypto"
|
||||
"github.com/caos/zitadel/internal/errors"
|
||||
"github.com/caos/zitadel/internal/eventstore/spooler"
|
||||
"github.com/caos/zitadel/internal/eventstore/v2"
|
||||
"github.com/caos/zitadel/internal/id"
|
||||
"github.com/caos/zitadel/internal/key/model"
|
||||
key_view "github.com/caos/zitadel/internal/key/repository/view"
|
||||
"github.com/caos/zitadel/internal/v2/command"
|
||||
)
|
||||
|
||||
type KeyRepository struct {
|
||||
View *view.View
|
||||
SigningKeyRotation time.Duration
|
||||
Commands *command.CommandSide
|
||||
Eventstore *eventstore.Eventstore
|
||||
View *view.View
|
||||
SigningKeyRotationCheck time.Duration
|
||||
SigningKeyGracefulPeriod time.Duration
|
||||
KeyAlgorithm crypto.EncryptionAlgorithm
|
||||
KeyChan <-chan *model.KeyView
|
||||
Locker spooler.Locker
|
||||
lockID string
|
||||
currentKeyID string
|
||||
currentKeyExpiration time.Time
|
||||
}
|
||||
|
||||
func (k *KeyRepository) GetSigningKey(ctx context.Context, keyCh chan<- jose.SigningKey, errCh chan<- error, renewTimer <-chan time.Time) {
|
||||
const (
|
||||
signingKey = "signing_key"
|
||||
)
|
||||
|
||||
func (k *KeyRepository) GetSigningKey(ctx context.Context, keyCh chan<- jose.SigningKey, algorithm string) {
|
||||
renewTimer := time.After(0)
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case key := <-k.KeyChan:
|
||||
refreshed, err := k.refreshSigningKey(ctx, key, keyCh, algorithm)
|
||||
logging.Log("KEY-asd5g").OnError(err).Error("could not refresh signing key on key channel push")
|
||||
k.setRenewTimer(renewTimer, refreshed)
|
||||
case <-renewTimer:
|
||||
k.refreshSigningKey(keyCh, errCh)
|
||||
renewTimer = time.After(k.SigningKeyRotation)
|
||||
key, err := k.latestSigningKey()
|
||||
logging.Log("KEY-DAfh4").OnError(err).Error("could not check for latest signing key")
|
||||
refreshed, err := k.refreshSigningKey(ctx, key, keyCh, algorithm)
|
||||
logging.Log("KEY-DAfh4").OnError(err).Error("could not refresh signing key when ensuring key")
|
||||
k.setRenewTimer(renewTimer, refreshed)
|
||||
}
|
||||
}
|
||||
}()
|
||||
@@ -40,17 +71,98 @@ func (k *KeyRepository) GetKeySet(ctx context.Context) (*jose.JSONWebKeySet, err
|
||||
return &jose.JSONWebKeySet{Keys: webKeys}, nil
|
||||
}
|
||||
|
||||
func (k *KeyRepository) refreshSigningKey(keyCh chan<- jose.SigningKey, errCh chan<- error) {
|
||||
key, err := k.View.GetSigningKey()
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
return
|
||||
func (k *KeyRepository) setRenewTimer(timer <-chan time.Time, refreshed bool) {
|
||||
duration := k.SigningKeyRotationCheck
|
||||
if refreshed {
|
||||
duration = k.currentKeyExpiration.Sub(time.Now().Add(k.SigningKeyGracefulPeriod + k.SigningKeyRotationCheck*2))
|
||||
}
|
||||
timer = time.After(duration)
|
||||
}
|
||||
|
||||
func (k *KeyRepository) latestSigningKey() (shortRefresh *model.KeyView, err error) {
|
||||
key, errView := k.View.GetActivePrivateKeyForSigning(time.Now().UTC().Add(k.SigningKeyGracefulPeriod))
|
||||
if errView != nil && !errors.IsNotFound(errView) {
|
||||
logging.Log("EVENT-GEd4h").WithError(errView).Warn("could not get signing key")
|
||||
return nil, errView
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func (k *KeyRepository) ensureIsLatestKey(ctx context.Context) (bool, error) {
|
||||
sequence, err := k.View.GetLatestKeySequence()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
events, err := k.getKeyEvents(ctx, sequence.CurrentSequence)
|
||||
if err != nil {
|
||||
logging.Log("EVENT-der5g").WithError(err).Warn("error retrieving new events")
|
||||
return false, err
|
||||
}
|
||||
if len(events) > 0 {
|
||||
logging.Log("EVENT-GBD23").Warn("view not up to date, retrying later")
|
||||
return false, nil
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (k *KeyRepository) refreshSigningKey(ctx context.Context, key *model.KeyView, keyCh chan<- jose.SigningKey, algorithm string) (refreshed bool, err error) {
|
||||
if key == nil {
|
||||
if k.currentKeyExpiration.Before(time.Now().UTC()) {
|
||||
keyCh <- jose.SigningKey{}
|
||||
}
|
||||
if ok, err := k.ensureIsLatestKey(ctx); !ok && err == nil {
|
||||
return false, err
|
||||
}
|
||||
err = k.lockAndGenerateSigningKeyPair(ctx, algorithm)
|
||||
logging.Log("EVENT-B4d21").OnError(err).Warn("could not create signing key")
|
||||
return false, err
|
||||
}
|
||||
if k.currentKeyID == key.ID {
|
||||
return false, nil
|
||||
}
|
||||
if ok, err := k.ensureIsLatestKey(ctx); !ok && err == nil {
|
||||
return false, err
|
||||
}
|
||||
signingKey, err := model.SigningKeyFromKeyView(key, k.KeyAlgorithm)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
k.currentKeyID = signingKey.ID
|
||||
k.currentKeyExpiration = key.Expiry
|
||||
keyCh <- jose.SigningKey{
|
||||
Algorithm: jose.SignatureAlgorithm(key.Algorithm),
|
||||
Algorithm: jose.SignatureAlgorithm(signingKey.Algorithm),
|
||||
Key: jose.JSONWebKey{
|
||||
KeyID: key.ID,
|
||||
Key: key.Key,
|
||||
KeyID: signingKey.ID,
|
||||
Key: signingKey.Key,
|
||||
},
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (k *KeyRepository) lockAndGenerateSigningKeyPair(ctx context.Context, algorithm string) error {
|
||||
err := k.Locker.Renew(k.lockerID(), signingKey, k.SigningKeyRotationCheck*2)
|
||||
if err != nil {
|
||||
if errors.IsErrorAlreadyExists(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
return k.Commands.GenerateSigningKeyPair(ctx, algorithm)
|
||||
}
|
||||
|
||||
func (k *KeyRepository) lockerID() string {
|
||||
if k.lockID == "" {
|
||||
var err error
|
||||
k.lockID, err = os.Hostname()
|
||||
if err != nil || k.lockID == "" {
|
||||
k.lockID, err = id.SonyFlakeGenerator.Next()
|
||||
logging.Log("EVENT-bsdf6").OnError(err).Panic("unable to generate lockID")
|
||||
}
|
||||
context.TODO()
|
||||
}
|
||||
return k.lockID
|
||||
}
|
||||
|
||||
func (k *KeyRepository) getKeyEvents(ctx context.Context, sequence uint64) ([]eventstore.EventReader, error) {
|
||||
return k.Eventstore.FilterEvents(ctx, key_view.KeyPairQuery(sequence))
|
||||
}
|
||||
|
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/caos/zitadel/internal/config/types"
|
||||
"github.com/caos/zitadel/internal/eventstore"
|
||||
"github.com/caos/zitadel/internal/eventstore/query"
|
||||
key_model "github.com/caos/zitadel/internal/key/model"
|
||||
)
|
||||
|
||||
type Configs map[string]*Config
|
||||
@@ -29,7 +30,7 @@ func (h *handler) Eventstore() eventstore.Eventstore {
|
||||
return h.es
|
||||
}
|
||||
|
||||
func Register(configs Configs, bulkLimit, errorCount uint64, view *view.View, es eventstore.Eventstore, systemDefaults sd.SystemDefaults) []query.Handler {
|
||||
func Register(configs Configs, bulkLimit, errorCount uint64, view *view.View, es eventstore.Eventstore, systemDefaults sd.SystemDefaults, keyChan chan<- *key_model.KeyView) []query.Handler {
|
||||
return []query.Handler{
|
||||
newUser(
|
||||
handler{view, bulkLimit, configs.cycleDuration("User"), errorCount, es},
|
||||
@@ -41,7 +42,8 @@ func Register(configs Configs, bulkLimit, errorCount uint64, view *view.View, es
|
||||
newToken(
|
||||
handler{view, bulkLimit, configs.cycleDuration("Token"), errorCount, es}),
|
||||
newKey(
|
||||
handler{view, bulkLimit, configs.cycleDuration("Key"), errorCount, es}),
|
||||
handler{view, bulkLimit, configs.cycleDuration("Key"), errorCount, es},
|
||||
keyChan),
|
||||
newApplication(handler{view, bulkLimit, configs.cycleDuration("Application"), errorCount, es}),
|
||||
newOrg(
|
||||
handler{view, bulkLimit, configs.cycleDuration("Org"), errorCount, es}),
|
||||
|
@@ -4,10 +4,12 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/caos/logging"
|
||||
|
||||
"github.com/caos/zitadel/internal/eventstore"
|
||||
"github.com/caos/zitadel/internal/eventstore/models"
|
||||
"github.com/caos/zitadel/internal/eventstore/query"
|
||||
"github.com/caos/zitadel/internal/eventstore/spooler"
|
||||
"github.com/caos/zitadel/internal/key/model"
|
||||
"github.com/caos/zitadel/internal/key/repository/eventsourcing"
|
||||
es_model "github.com/caos/zitadel/internal/key/repository/eventsourcing/model"
|
||||
view_model "github.com/caos/zitadel/internal/key/repository/view/model"
|
||||
@@ -20,11 +22,13 @@ const (
|
||||
type Key struct {
|
||||
handler
|
||||
subscription *eventstore.Subscription
|
||||
keyChan chan<- *model.KeyView
|
||||
}
|
||||
|
||||
func newKey(handler handler) *Key {
|
||||
func newKey(handler handler, keyChan chan<- *model.KeyView) *Key {
|
||||
h := &Key{
|
||||
handler: handler,
|
||||
keyChan: keyChan,
|
||||
}
|
||||
|
||||
h.subscribe()
|
||||
@@ -75,7 +79,12 @@ func (k *Key) Reduce(event *models.Event) error {
|
||||
if privateKey.Expiry.Before(time.Now()) && publicKey.Expiry.Before(time.Now()) {
|
||||
return k.view.ProcessedKeySequence(event)
|
||||
}
|
||||
return k.view.PutKeys(privateKey, publicKey, event)
|
||||
err = k.view.PutKeys(privateKey, publicKey, event)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
k.keyChan <- view_model.KeyViewToModel(privateKey)
|
||||
return nil
|
||||
default:
|
||||
return k.view.ProcessedKeySequence(event)
|
||||
}
|
||||
|
@@ -15,6 +15,7 @@ import (
|
||||
es_int "github.com/caos/zitadel/internal/eventstore"
|
||||
es_spol "github.com/caos/zitadel/internal/eventstore/spooler"
|
||||
"github.com/caos/zitadel/internal/id"
|
||||
key_model "github.com/caos/zitadel/internal/key/model"
|
||||
"github.com/caos/zitadel/internal/v2/command"
|
||||
"github.com/caos/zitadel/internal/v2/query"
|
||||
)
|
||||
@@ -75,7 +76,9 @@ func Start(conf Config, authZ authz.Config, systemDefaults sd.SystemDefaults, co
|
||||
return nil, err
|
||||
}
|
||||
|
||||
spool := spooler.StartSpooler(conf.Spooler, es, view, sqlClient, systemDefaults)
|
||||
keyChan := make(chan *key_model.KeyView)
|
||||
spool := spooler.StartSpooler(conf.Spooler, es, view, sqlClient, systemDefaults, keyChan)
|
||||
locker := spooler.NewLocker(sqlClient)
|
||||
|
||||
userRepo := eventstore.UserRepo{
|
||||
SearchLimit: conf.SearchLimit,
|
||||
@@ -108,12 +111,15 @@ func Start(conf Config, authZ authz.Config, systemDefaults sd.SystemDefaults, co
|
||||
IAMID: systemDefaults.IamID,
|
||||
},
|
||||
eventstore.TokenRepo{
|
||||
Eventstore: es,
|
||||
View: view,
|
||||
},
|
||||
eventstore.KeyRepository{
|
||||
View: view,
|
||||
SigningKeyRotation: systemDefaults.KeyConfig.SigningKeyRotation.Duration,
|
||||
View: view,
|
||||
SigningKeyRotationCheck: systemDefaults.KeyConfig.SigningKeyRotationCheck.Duration,
|
||||
SigningKeyGracefulPeriod: systemDefaults.KeyConfig.SigningKeyGracefulPeriod.Duration,
|
||||
KeyAlgorithm: keyAlgorithm,
|
||||
Locker: locker,
|
||||
KeyChan: keyChan,
|
||||
},
|
||||
eventstore.ApplicationRepo{
|
||||
Commands: command,
|
||||
|
@@ -15,6 +15,10 @@ type locker struct {
|
||||
dbClient *sql.DB
|
||||
}
|
||||
|
||||
func NewLocker(client *sql.DB) *locker {
|
||||
return &locker{dbClient: client}
|
||||
}
|
||||
|
||||
func (l *locker) Renew(lockerID, viewModel string, waitTime time.Duration) error {
|
||||
return es_locker.Renew(l.dbClient, lockTable, lockerID, viewModel, waitTime)
|
||||
}
|
||||
|
@@ -3,13 +3,12 @@ package spooler
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
sd "github.com/caos/zitadel/internal/config/systemdefaults"
|
||||
|
||||
"github.com/caos/zitadel/internal/auth/repository/eventsourcing/handler"
|
||||
"github.com/caos/zitadel/internal/auth/repository/eventsourcing/view"
|
||||
|
||||
sd "github.com/caos/zitadel/internal/config/systemdefaults"
|
||||
"github.com/caos/zitadel/internal/eventstore"
|
||||
"github.com/caos/zitadel/internal/eventstore/spooler"
|
||||
key_model "github.com/caos/zitadel/internal/key/model"
|
||||
)
|
||||
|
||||
type SpoolerConfig struct {
|
||||
@@ -19,12 +18,12 @@ type SpoolerConfig struct {
|
||||
Handlers handler.Configs
|
||||
}
|
||||
|
||||
func StartSpooler(c SpoolerConfig, es eventstore.Eventstore, view *view.View, client *sql.DB, systemDefaults sd.SystemDefaults) *spooler.Spooler {
|
||||
func StartSpooler(c SpoolerConfig, es eventstore.Eventstore, view *view.View, client *sql.DB, systemDefaults sd.SystemDefaults, keyChan chan<- *key_model.KeyView) *spooler.Spooler {
|
||||
spoolerConfig := spooler.Config{
|
||||
Eventstore: es,
|
||||
Locker: &locker{dbClient: client},
|
||||
ConcurrentWorkers: c.ConcurrentWorkers,
|
||||
ViewHandlers: handler.Register(c.Handlers, c.BulkLimit, c.FailureCountUntilSkip, view, es, systemDefaults),
|
||||
ViewHandlers: handler.Register(c.Handlers, c.BulkLimit, c.FailureCountUntilSkip, view, es, systemDefaults, keyChan),
|
||||
}
|
||||
spool := spoolerConfig.New()
|
||||
spool.Start()
|
||||
|
@@ -1,6 +1,8 @@
|
||||
package view
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/caos/zitadel/internal/errors"
|
||||
"github.com/caos/zitadel/internal/eventstore/models"
|
||||
key_model "github.com/caos/zitadel/internal/key/model"
|
||||
@@ -17,12 +19,21 @@ func (v *View) KeyByIDAndType(keyID string, private bool) (*model.KeyView, error
|
||||
return view.KeyByIDAndType(v.Db, keyTable, keyID, private)
|
||||
}
|
||||
|
||||
func (v *View) GetSigningKey() (*key_model.SigningKey, error) {
|
||||
key, err := view.GetSigningKey(v.Db, keyTable)
|
||||
func (v *View) GetActivePrivateKeyForSigning(expiry time.Time) (*key_model.KeyView, error) {
|
||||
key, err := view.GetSigningKey(v.Db, keyTable, expiry)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return key_model.SigningKeyFromKeyView(model.KeyViewToModel(key), v.keyAlgorithm)
|
||||
return model.KeyViewToModel(key), nil
|
||||
}
|
||||
|
||||
func (v *View) GetSigningKey(expiry time.Time) (*key_model.SigningKey, time.Time, error) {
|
||||
key, err := view.GetSigningKey(v.Db, keyTable, expiry)
|
||||
if err != nil {
|
||||
return nil, time.Time{}, err
|
||||
}
|
||||
signingKey, err := key_model.SigningKeyFromKeyView(model.KeyViewToModel(key), v.keyAlgorithm)
|
||||
return signingKey, key.Expiry, err
|
||||
}
|
||||
|
||||
func (v *View) GetActiveKeySet() ([]*key_model.PublicKey, error) {
|
||||
|
@@ -2,12 +2,11 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
)
|
||||
|
||||
type KeyRepository interface {
|
||||
GetSigningKey(ctx context.Context, keyCh chan<- jose.SigningKey, errCh chan<- error, timer <-chan time.Time)
|
||||
GetSigningKey(ctx context.Context, keyCh chan<- jose.SigningKey, algorithm string)
|
||||
GetKeySet(ctx context.Context) (*jose.JSONWebKeySet, error)
|
||||
}
|
||||
|
Reference in New Issue
Block a user