improve keyset caching

This commit is contained in:
Tim Möhlmann 2023-11-02 18:55:48 +02:00
parent 9f7f715259
commit b816b6f29d
3 changed files with 69 additions and 23 deletions

View File

@ -21,24 +21,77 @@ import (
) )
type keySet struct { type keySet struct {
mtx sync.RWMutex mtx sync.RWMutex
keys map[string]query.PublicKey instanceKeys map[string]map[string]query.PublicKey
queryKey func(ctx context.Context, keyID string, current time.Time) (query.PublicKey, error) queryKey func(ctx context.Context, keyID string, current time.Time) (query.PublicKey, error)
}
func newKeySet(background context.Context, purgeInterval time.Duration, queryKey func(ctx context.Context, keyID string, current time.Time) (query.PublicKey, error)) *keySet {
k := &keySet{
instanceKeys: make(map[string]map[string]query.PublicKey),
queryKey: queryKey,
}
go k.purgeOnInterval(background, purgeInterval)
return k
}
func (v *keySet) purgeOnInterval(background context.Context, purgeInterval time.Duration) {
timer := time.NewTimer(purgeInterval)
defer func() {
if !timer.Stop() {
<-timer.C // make sure the channel is emptied
}
}()
loop:
for {
select {
case <-background.Done():
break loop
case <-timer.C:
timer.Reset(purgeInterval)
}
// do the actual purging
v.mtx.Lock()
for instanceID, keys := range v.instanceKeys {
for keyID, key := range keys {
if key.Expiry().Before(time.Now()) {
delete(keys, keyID)
}
}
if len(keys) == 0 {
delete(v.instanceKeys, instanceID)
}
}
v.mtx.Unlock()
}
}
func (v *keySet) setKey(instanceID, keyID string, key query.PublicKey) {
v.mtx.Lock()
defer v.mtx.Unlock()
if keys, ok := v.instanceKeys[instanceID]; ok {
keys[keyID] = key
return
}
v.instanceKeys[instanceID] = map[string]query.PublicKey{keyID: key}
} }
func (v *keySet) getKey(ctx context.Context, keyID string, current time.Time) (*jose.JSONWebKey, error) { func (v *keySet) getKey(ctx context.Context, keyID string, current time.Time) (*jose.JSONWebKey, error) {
instanceID := authz.GetInstance(ctx).InstanceID()
v.mtx.RLock() v.mtx.RLock()
key, ok := v.keys[keyID] key, ok := v.instanceKeys[instanceID][keyID]
v.mtx.RUnlock() v.mtx.RUnlock()
if ok { if ok {
if key.Expiry().After(current) { if key.Expiry().After(current) {
return jsonWebkey(key), nil return jsonWebkey(key), nil
} }
v.mtx.Lock()
delete(v.keys, keyID) // cleanup expired keys
v.mtx.Unlock()
return nil, errors.ThrowInvalidArgument(nil, "OIDC-Zoh9E", "Errors.Key.ExpireBeforeNow") return nil, errors.ThrowInvalidArgument(nil, "OIDC-Zoh9E", "Errors.Key.ExpireBeforeNow")
} }
@ -46,11 +99,7 @@ func (v *keySet) getKey(ctx context.Context, keyID string, current time.Time) (*
if err != nil { if err != nil {
return nil, err return nil, err
} }
v.setKey(instanceID, keyID, key)
v.mtx.Lock()
v.keys[key.ID()] = key
v.mtx.Unlock()
return jsonWebkey(key), nil return jsonWebkey(key), nil
} }

View File

@ -176,14 +176,11 @@ func createOPConfig(config Config, defaultLogoutRedirectURI string, cryptoKey []
func newStorage(config Config, command *command.Commands, queries *query.Queries, repo repository.Repository, encAlg crypto.EncryptionAlgorithm, es *eventstore.Eventstore, db *database.DB, externalSecure bool) *OPStorage { func newStorage(config Config, command *command.Commands, queries *query.Queries, repo repository.Repository, encAlg crypto.EncryptionAlgorithm, es *eventstore.Eventstore, db *database.DB, externalSecure bool) *OPStorage {
return &OPStorage{ return &OPStorage{
repo: repo, repo: repo,
command: command, command: command,
query: queries, query: queries,
eventstore: es, eventstore: es,
keySet: &keySet{ keySet: newKeySet(context.TODO(), time.Hour, queries.GetActivePublicKeyByID),
keys: make(map[string]query.PublicKey),
queryKey: queries.GetActivePublicKeyByID,
},
defaultLoginURL: fmt.Sprintf("%s%s?%s=", login.HandlerPrefix, login.EndpointLogin, login.QueryAuthRequestID), defaultLoginURL: fmt.Sprintf("%s%s?%s=", login.HandlerPrefix, login.EndpointLogin, login.QueryAuthRequestID),
defaultLoginURLV2: config.DefaultLoginURLV2, defaultLoginURLV2: config.DefaultLoginURLV2,
defaultLogoutURLV2: config.DefaultLogoutURLV2, defaultLogoutURLV2: config.DefaultLogoutURLV2,

View File

@ -395,7 +395,7 @@ func (wm *PublicKeyReadModel) Query() *eventstore.SearchQueryBuilder {
Builder() Builder()
} }
func (q *Queries) GetActivePublicKeyByID(ctx context.Context, keyID string, after time.Time) (_ PublicKey, err error) { func (q *Queries) GetActivePublicKeyByID(ctx context.Context, keyID string, current time.Time) (_ PublicKey, err error) {
model := NewPublicKeyReadModel(keyID, authz.GetInstance(ctx).InstanceID()) model := NewPublicKeyReadModel(keyID, authz.GetInstance(ctx).InstanceID())
if err := q.eventstore.FilterToQueryReducer(ctx, model); err != nil { if err := q.eventstore.FilterToQueryReducer(ctx, model); err != nil {
return nil, err return nil, err
@ -403,7 +403,7 @@ func (q *Queries) GetActivePublicKeyByID(ctx context.Context, keyID string, afte
if model.Algorithm == "" || model.Key == nil { if model.Algorithm == "" || model.Key == nil {
return nil, errors.ThrowNotFound(err, "QUERY-Ahf7x", "Errors.Key.NotFound") return nil, errors.ThrowNotFound(err, "QUERY-Ahf7x", "Errors.Key.NotFound")
} }
if model.Expiry.After(after) { if model.Expiry.Before(current) {
return nil, errors.ThrowInvalidArgument(err, "QUERY-ciF4k", "Errors.Key.ExpireBeforeNow") return nil, errors.ThrowInvalidArgument(err, "QUERY-ciF4k", "Errors.Key.ExpireBeforeNow")
} }
keyValue, err := crypto.Decrypt(model.Key, q.keyEncryptionAlgorithm) keyValue, err := crypto.Decrypt(model.Key, q.keyEncryptionAlgorithm)