mirror of
https://github.com/zitadel/zitadel.git
synced 2025-01-10 09:53:40 +00:00
improve keyset caching
This commit is contained in:
parent
9f7f715259
commit
b816b6f29d
@ -21,24 +21,77 @@ import (
|
||||
)
|
||||
|
||||
type keySet struct {
|
||||
mtx sync.RWMutex
|
||||
keys map[string]query.PublicKey
|
||||
queryKey func(ctx context.Context, keyID string, current time.Time) (query.PublicKey, error)
|
||||
mtx sync.RWMutex
|
||||
instanceKeys map[string]map[string]query.PublicKey
|
||||
queryKey func(ctx context.Context, keyID string, current time.Time) (query.PublicKey, error)
|
||||
}
|
||||
|
||||
func newKeySet(background context.Context, purgeInterval time.Duration, queryKey func(ctx context.Context, keyID string, current time.Time) (query.PublicKey, error)) *keySet {
|
||||
k := &keySet{
|
||||
instanceKeys: make(map[string]map[string]query.PublicKey),
|
||||
queryKey: queryKey,
|
||||
}
|
||||
go k.purgeOnInterval(background, purgeInterval)
|
||||
return k
|
||||
}
|
||||
|
||||
func (v *keySet) purgeOnInterval(background context.Context, purgeInterval time.Duration) {
|
||||
timer := time.NewTimer(purgeInterval)
|
||||
defer func() {
|
||||
if !timer.Stop() {
|
||||
<-timer.C // make sure the channel is emptied
|
||||
}
|
||||
}()
|
||||
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case <-background.Done():
|
||||
break loop
|
||||
case <-timer.C:
|
||||
timer.Reset(purgeInterval)
|
||||
}
|
||||
|
||||
// do the actual purging
|
||||
v.mtx.Lock()
|
||||
for instanceID, keys := range v.instanceKeys {
|
||||
for keyID, key := range keys {
|
||||
if key.Expiry().Before(time.Now()) {
|
||||
delete(keys, keyID)
|
||||
}
|
||||
}
|
||||
if len(keys) == 0 {
|
||||
delete(v.instanceKeys, instanceID)
|
||||
}
|
||||
}
|
||||
v.mtx.Unlock()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (v *keySet) setKey(instanceID, keyID string, key query.PublicKey) {
|
||||
v.mtx.Lock()
|
||||
defer v.mtx.Unlock()
|
||||
|
||||
if keys, ok := v.instanceKeys[instanceID]; ok {
|
||||
keys[keyID] = key
|
||||
return
|
||||
}
|
||||
|
||||
v.instanceKeys[instanceID] = map[string]query.PublicKey{keyID: key}
|
||||
}
|
||||
|
||||
func (v *keySet) getKey(ctx context.Context, keyID string, current time.Time) (*jose.JSONWebKey, error) {
|
||||
instanceID := authz.GetInstance(ctx).InstanceID()
|
||||
|
||||
v.mtx.RLock()
|
||||
key, ok := v.keys[keyID]
|
||||
key, ok := v.instanceKeys[instanceID][keyID]
|
||||
v.mtx.RUnlock()
|
||||
|
||||
if ok {
|
||||
if key.Expiry().After(current) {
|
||||
return jsonWebkey(key), nil
|
||||
}
|
||||
v.mtx.Lock()
|
||||
delete(v.keys, keyID) // cleanup expired keys
|
||||
v.mtx.Unlock()
|
||||
|
||||
return nil, errors.ThrowInvalidArgument(nil, "OIDC-Zoh9E", "Errors.Key.ExpireBeforeNow")
|
||||
}
|
||||
|
||||
@ -46,11 +99,7 @@ func (v *keySet) getKey(ctx context.Context, keyID string, current time.Time) (*
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
v.mtx.Lock()
|
||||
v.keys[key.ID()] = key
|
||||
v.mtx.Unlock()
|
||||
|
||||
v.setKey(instanceID, keyID, key)
|
||||
return jsonWebkey(key), nil
|
||||
}
|
||||
|
||||
|
@ -176,14 +176,11 @@ func createOPConfig(config Config, defaultLogoutRedirectURI string, cryptoKey []
|
||||
|
||||
func newStorage(config Config, command *command.Commands, queries *query.Queries, repo repository.Repository, encAlg crypto.EncryptionAlgorithm, es *eventstore.Eventstore, db *database.DB, externalSecure bool) *OPStorage {
|
||||
return &OPStorage{
|
||||
repo: repo,
|
||||
command: command,
|
||||
query: queries,
|
||||
eventstore: es,
|
||||
keySet: &keySet{
|
||||
keys: make(map[string]query.PublicKey),
|
||||
queryKey: queries.GetActivePublicKeyByID,
|
||||
},
|
||||
repo: repo,
|
||||
command: command,
|
||||
query: queries,
|
||||
eventstore: es,
|
||||
keySet: newKeySet(context.TODO(), time.Hour, queries.GetActivePublicKeyByID),
|
||||
defaultLoginURL: fmt.Sprintf("%s%s?%s=", login.HandlerPrefix, login.EndpointLogin, login.QueryAuthRequestID),
|
||||
defaultLoginURLV2: config.DefaultLoginURLV2,
|
||||
defaultLogoutURLV2: config.DefaultLogoutURLV2,
|
||||
|
@ -395,7 +395,7 @@ func (wm *PublicKeyReadModel) Query() *eventstore.SearchQueryBuilder {
|
||||
Builder()
|
||||
}
|
||||
|
||||
func (q *Queries) GetActivePublicKeyByID(ctx context.Context, keyID string, after time.Time) (_ PublicKey, err error) {
|
||||
func (q *Queries) GetActivePublicKeyByID(ctx context.Context, keyID string, current time.Time) (_ PublicKey, err error) {
|
||||
model := NewPublicKeyReadModel(keyID, authz.GetInstance(ctx).InstanceID())
|
||||
if err := q.eventstore.FilterToQueryReducer(ctx, model); err != nil {
|
||||
return nil, err
|
||||
@ -403,7 +403,7 @@ func (q *Queries) GetActivePublicKeyByID(ctx context.Context, keyID string, afte
|
||||
if model.Algorithm == "" || model.Key == nil {
|
||||
return nil, errors.ThrowNotFound(err, "QUERY-Ahf7x", "Errors.Key.NotFound")
|
||||
}
|
||||
if model.Expiry.After(after) {
|
||||
if model.Expiry.Before(current) {
|
||||
return nil, errors.ThrowInvalidArgument(err, "QUERY-ciF4k", "Errors.Key.ExpireBeforeNow")
|
||||
}
|
||||
keyValue, err := crypto.Decrypt(model.Key, q.keyEncryptionAlgorithm)
|
||||
|
Loading…
x
Reference in New Issue
Block a user