fix(query): keys (#2755)

* fix: add keys to projections

* change to multiple tables

* query keys

* query keys

* fix race condition

* fix timer reset

* begin tests

* tests

* remove migration

* only send to keyChannel if not nil
This commit is contained in:
Livio Amstutz
2022-01-12 13:22:04 +01:00
committed by GitHub
parent ead61d240d
commit 9ab566fdeb
23 changed files with 927 additions and 419 deletions

View File

@@ -9,7 +9,6 @@ import (
"github.com/caos/logging"
"github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/op"
"gopkg.in/square/go-jose.v2"
"github.com/caos/zitadel/internal/api/http/middleware"
"github.com/caos/zitadel/internal/errors"
@@ -198,16 +197,6 @@ func (o *OPStorage) RevokeToken(ctx context.Context, token, userID, clientID str
return oidc.ErrServerError().WithParent(err)
}
func (o *OPStorage) GetSigningKey(ctx context.Context, keyCh chan<- jose.SigningKey) {
o.repo.GetSigningKey(ctx, keyCh, o.signingKeyAlgorithm)
}
func (o *OPStorage) GetKeySet(ctx context.Context) (_ *jose.JSONWebKeySet, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
return o.repo.GetKeySet(ctx)
}
func (o *OPStorage) assertProjectRoleScopes(ctx context.Context, clientID string, scopes []string) ([]string, error) {
for _, scope := range scopes {
if strings.HasPrefix(scope, ScopeProjectRolePrefix) {

179
internal/api/oidc/key.go Normal file
View File

@@ -0,0 +1,179 @@
package oidc
import (
"context"
"fmt"
"time"
"github.com/caos/logging"
"gopkg.in/square/go-jose.v2"
"github.com/caos/zitadel/internal/telemetry/tracing"
"github.com/caos/zitadel/internal/crypto"
"github.com/caos/zitadel/internal/domain"
"github.com/caos/zitadel/internal/errors"
"github.com/caos/zitadel/internal/eventstore"
"github.com/caos/zitadel/internal/query"
"github.com/caos/zitadel/internal/repository/keypair"
)
const (
locksTable = "projections.locks"
signingKey = "signing_key"
)
func (o *OPStorage) GetKeySet(ctx context.Context) (_ *jose.JSONWebKeySet, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
keys, err := o.query.ActivePublicKeys(ctx, time.Now())
if err != nil {
return nil, err
}
webKeys := make([]jose.JSONWebKey, len(keys.Keys))
for i, key := range keys.Keys {
webKeys[i] = jose.JSONWebKey{
KeyID: key.ID(),
Algorithm: key.Algorithm(),
Use: key.Use().String(),
Key: key.Key(),
}
}
return &jose.JSONWebKeySet{Keys: webKeys}, nil
}
func (o *OPStorage) GetSigningKey(ctx context.Context, keyCh chan<- jose.SigningKey) {
renewTimer := time.NewTimer(0)
go func() {
for {
select {
case <-ctx.Done():
return
case <-o.keyChan:
if !renewTimer.Stop() {
<-renewTimer.C
}
checkAfter := o.resetTimer(renewTimer, true)
logging.Log("OIDC-dK432").Infof("requested next signing key check in %s", checkAfter)
case <-renewTimer.C:
o.getSigningKey(ctx, renewTimer, keyCh)
}
}
}()
}
func (o *OPStorage) getSigningKey(ctx context.Context, renewTimer *time.Timer, keyCh chan<- jose.SigningKey) {
keys, err := o.query.ActivePrivateSigningKey(ctx, time.Now().Add(o.signingKeyGracefulPeriod))
if err != nil {
checkAfter := o.resetTimer(renewTimer, true)
logging.Log("OIDC-ASff").Infof("next signing key check in %s", checkAfter)
return
}
if len(keys.Keys) == 0 {
o.refreshSigningKey(ctx, keyCh, o.signingKeyAlgorithm, keys.LatestSequence)
checkAfter := o.resetTimer(renewTimer, true)
logging.Log("OIDC-ASDf3").Infof("next signing key check in %s", checkAfter)
return
}
err = o.exchangeSigningKey(selectSigningKey(keys.Keys), keyCh)
logging.Log("OIDC-aDfg3").OnError(err).Error("could not exchange signing key")
checkAfter := o.resetTimer(renewTimer, err != nil)
logging.Log("OIDC-dK432").Infof("next signing key check in %s", checkAfter)
}
func (o *OPStorage) resetTimer(timer *time.Timer, shortRefresh bool) (nextCheck time.Duration) {
nextCheck = o.signingKeyRotationCheck
defer func() { timer.Reset(nextCheck) }()
if shortRefresh || o.currentKey == nil {
return nextCheck
}
maxLifetime := time.Until(o.currentKey.Expiry())
if maxLifetime < o.signingKeyGracefulPeriod+2*o.signingKeyRotationCheck {
return nextCheck
}
return maxLifetime - o.signingKeyGracefulPeriod - o.signingKeyRotationCheck
}
func (o *OPStorage) refreshSigningKey(ctx context.Context, keyCh chan<- jose.SigningKey, algorithm string, sequence *query.LatestSequence) {
if o.currentKey != nil && o.currentKey.Expiry().Before(time.Now().UTC()) {
logging.Log("OIDC-ADg26").Info("unset current signing key")
keyCh <- jose.SigningKey{}
}
ok, err := o.ensureIsLatestKey(ctx, sequence.Sequence)
if err != nil {
logging.Log("OIDC-sdz53").WithError(err).Error("could not ensure latest key")
return
}
if !ok {
logging.Log("EVENT-GBD23").Warn("view not up to date, retrying later")
return
}
err = o.lockAndGenerateSigningKeyPair(ctx, algorithm)
logging.Log("EVENT-B4d21").OnError(err).Warn("could not create signing key")
}
func (o *OPStorage) ensureIsLatestKey(ctx context.Context, sequence uint64) (bool, error) {
maxSequence, err := o.getMaxKeySequence(ctx)
if err != nil {
return false, fmt.Errorf("error retrieving new events: %w", err)
}
return sequence == maxSequence, nil
}
func (o *OPStorage) exchangeSigningKey(key query.PrivateKey, keyCh chan<- jose.SigningKey) (err error) {
if o.currentKey != nil && o.currentKey.ID() == key.ID() {
logging.Log("OIDC-Abb3e").Info("no new signing key")
return nil
}
keyData, err := crypto.Decrypt(key.Key(), o.encAlg)
if err != nil {
return err
}
privateKey, err := crypto.BytesToPrivateKey(keyData)
if err != nil {
return err
}
keyCh <- jose.SigningKey{
Algorithm: jose.SignatureAlgorithm(key.Algorithm()),
Key: jose.JSONWebKey{
KeyID: key.ID(),
Key: privateKey,
},
}
o.currentKey = key
logging.LogWithFields("OIDC-dsg54", "keyID", key.ID()).Info("exchanged signing key")
return nil
}
func (o *OPStorage) lockAndGenerateSigningKeyPair(ctx context.Context, algorithm string) error {
logging.Log("OIDC-sdz53").Info("lock and generate signing key pair")
ctx, cancel := context.WithCancel(ctx)
defer cancel()
errs := o.locker.Lock(ctx, o.signingKeyRotationCheck*2)
err, ok := <-errs
if err != nil || !ok {
if errors.IsErrorAlreadyExists(err) {
return nil
}
logging.Log("OIDC-Dfg32").OnError(err).Warn("initial lock failed")
return err
}
return o.command.GenerateSigningKeyPair(ctx, algorithm)
}
func (o *OPStorage) getMaxKeySequence(ctx context.Context) (uint64, error) {
return o.eventstore.LatestSequence(ctx,
eventstore.NewSearchQueryBuilder(eventstore.ColumnsMaxSequence).
ResourceOwner(domain.IAMID).
AddQuery().
AggregateTypes(keypair.AggregateType).
Builder(),
)
}
func selectSigningKey(keys []query.PrivateKey) query.PrivateKey {
return keys[len(keys)-1]
}

View File

@@ -13,8 +13,11 @@ import (
"github.com/caos/zitadel/internal/api/http/middleware"
"github.com/caos/zitadel/internal/auth/repository"
"github.com/caos/zitadel/internal/command"
"github.com/caos/zitadel/internal/config/systemdefaults"
"github.com/caos/zitadel/internal/config/types"
"github.com/caos/zitadel/internal/crypto"
"github.com/caos/zitadel/internal/eventstore"
"github.com/caos/zitadel/internal/eventstore/handler/crdb"
"github.com/caos/zitadel/internal/i18n"
"github.com/caos/zitadel/internal/id"
"github.com/caos/zitadel/internal/query"
@@ -58,18 +61,25 @@ type OPStorage struct {
repo repository.Repository
command *command.Commands
query *query.Queries
eventstore *eventstore.Eventstore
defaultLoginURL string
defaultAccessTokenLifetime time.Duration
defaultIdTokenLifetime time.Duration
signingKeyAlgorithm string
defaultRefreshTokenIdleExpiration time.Duration
defaultRefreshTokenExpiration time.Duration
encAlg crypto.EncryptionAlgorithm
keyChan <-chan interface{}
currentKey query.PrivateKey
signingKeyRotationCheck time.Duration
signingKeyGracefulPeriod time.Duration
locker crdb.Locker
}
func NewProvider(ctx context.Context, config OPHandlerConfig, command *command.Commands, query *query.Queries, repo repository.Repository, keyConfig *crypto.KeyConfig, localDevMode bool) op.OpenIDProvider {
func NewProvider(ctx context.Context, config OPHandlerConfig, command *command.Commands, query *query.Queries, repo repository.Repository, keyConfig systemdefaults.KeyConfig, localDevMode bool, es *eventstore.Eventstore, projections types.SQL, keyChan <-chan interface{}) op.OpenIDProvider {
cookieHandler, err := middleware.NewUserAgentHandler(config.UserAgentCookieConfig, id.SonyFlakeGenerator, localDevMode)
logging.Log("OIDC-sd4fd").OnError(err).WithField("traceID", tracing.TraceIDFromCtx(ctx)).Panic("cannot user agent handler")
tokenKey, err := crypto.LoadKey(keyConfig, keyConfig.EncryptionKeyID)
tokenKey, err := crypto.LoadKey(keyConfig.EncryptionConfig, keyConfig.EncryptionConfig.EncryptionKeyID)
logging.Log("OIDC-ADvbv").OnError(err).Panic("cannot load OP crypto key")
cryptoKey := []byte(tokenKey)
if len(cryptoKey) != 32 {
@@ -84,10 +94,12 @@ func NewProvider(ctx context.Context, config OPHandlerConfig, command *command.C
logging.Log("OIDC-GBd3t").OnError(err).Panic("cannot get supported languages")
config.OPConfig.SupportedUILocales = supportedLanguages
metricTypes := []metrics.MetricType{metrics.MetricTypeRequestCount, metrics.MetricTypeStatusCode, metrics.MetricTypeTotalCount}
storage, err := newStorage(config.StorageConfig, command, query, repo, keyConfig, es, projections, keyChan)
logging.Log("OIDC-Jdg2k").OnError(err).WithField("traceID", tracing.TraceIDFromCtx(ctx)).Panic("cannot create storage")
provider, err := op.NewOpenIDProvider(
ctx,
config.OPConfig,
newStorage(config.StorageConfig, command, query, repo),
storage,
op.WithHttpInterceptors(
middleware.MetricsHandler(metricTypes),
middleware.TelemetryHandler(),
@@ -107,18 +119,32 @@ func NewProvider(ctx context.Context, config OPHandlerConfig, command *command.C
return provider
}
func newStorage(config StorageConfig, command *command.Commands, query *query.Queries, repo repository.Repository) *OPStorage {
func newStorage(config StorageConfig, command *command.Commands, query *query.Queries, repo repository.Repository, keyConfig systemdefaults.KeyConfig, es *eventstore.Eventstore, projections types.SQL, keyChan <-chan interface{}) (*OPStorage, error) {
encAlg, err := crypto.NewAESCrypto(keyConfig.EncryptionConfig)
if err != nil {
return nil, err
}
sqlClient, err := projections.Start()
if err != nil {
return nil, err
}
return &OPStorage{
repo: repo,
command: command,
query: query,
eventstore: es,
defaultLoginURL: config.DefaultLoginURL,
signingKeyAlgorithm: config.SigningKeyAlgorithm,
defaultAccessTokenLifetime: config.DefaultAccessTokenLifetime.Duration,
defaultIdTokenLifetime: config.DefaultIdTokenLifetime.Duration,
defaultRefreshTokenIdleExpiration: config.DefaultRefreshTokenIdleExpiration.Duration,
defaultRefreshTokenExpiration: config.DefaultRefreshTokenExpiration.Duration,
}
encAlg: encAlg,
signingKeyGracefulPeriod: keyConfig.SigningKeyGracefulPeriod.Duration,
signingKeyRotationCheck: keyConfig.SigningKeyRotationCheck.Duration,
locker: crdb.NewLocker(sqlClient, locksTable, signingKey),
keyChan: keyChan,
}, nil
}
func (o *OPStorage) Health(ctx context.Context) error {