client and project in single query

This commit is contained in:
Tim Möhlmann
2023-11-05 13:18:17 +02:00
parent 36baf36877
commit 66f91cdc4e
6 changed files with 173 additions and 48 deletions

View File

@@ -20,14 +20,21 @@ import (
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
type keySet struct {
// keySetCache implements oidc.KeySet for Access Token verification.
// Public Keys are cached in a 2-dimentional map of Instance ID and Key ID.
// When a key is not present the queryKey function is called to obtain the key
// from the database.
type keySetCache struct {
mtx sync.RWMutex
instanceKeys map[string]map[string]query.PublicKey
queryKey func(ctx context.Context, keyID string, current time.Time) (query.PublicKey, error)
}
func newKeySet(background context.Context, purgeInterval time.Duration, queryKey func(ctx context.Context, keyID string, current time.Time) (query.PublicKey, error)) *keySet {
k := &keySet{
// newKeySet initializes a keySetCache and starts a purging Go routine,
// which runs once every purgeInterval.
// When the passed context is done, the purge routine will terminate.
func newKeySet(background context.Context, purgeInterval time.Duration, queryKey func(ctx context.Context, keyID string, current time.Time) (query.PublicKey, error)) *keySetCache {
k := &keySetCache{
instanceKeys: make(map[string]map[string]query.PublicKey),
queryKey: queryKey,
}
@@ -35,7 +42,7 @@ func newKeySet(background context.Context, purgeInterval time.Duration, queryKey
return k
}
func (v *keySet) purgeOnInterval(background context.Context, purgeInterval time.Duration) {
func (k *keySetCache) purgeOnInterval(background context.Context, purgeInterval time.Duration) {
timer := time.NewTimer(purgeInterval)
defer func() {
if !timer.Stop() {
@@ -43,50 +50,49 @@ func (v *keySet) purgeOnInterval(background context.Context, purgeInterval time.
}
}()
loop:
for {
select {
case <-background.Done():
break loop
return
case <-timer.C:
timer.Reset(purgeInterval)
}
// do the actual purging
v.mtx.Lock()
for instanceID, keys := range v.instanceKeys {
k.mtx.Lock()
for instanceID, keys := range k.instanceKeys {
for keyID, key := range keys {
if key.Expiry().Before(time.Now()) {
delete(keys, keyID)
}
}
if len(keys) == 0 {
delete(v.instanceKeys, instanceID)
delete(k.instanceKeys, instanceID)
}
}
v.mtx.Unlock()
k.mtx.Unlock()
}
}
func (v *keySet) setKey(instanceID, keyID string, key query.PublicKey) {
v.mtx.Lock()
defer v.mtx.Unlock()
func (k *keySetCache) setKey(instanceID, keyID string, key query.PublicKey) {
k.mtx.Lock()
defer k.mtx.Unlock()
if keys, ok := v.instanceKeys[instanceID]; ok {
if keys, ok := k.instanceKeys[instanceID]; ok {
keys[keyID] = key
return
}
v.instanceKeys[instanceID] = map[string]query.PublicKey{keyID: key}
k.instanceKeys[instanceID] = map[string]query.PublicKey{keyID: key}
}
func (v *keySet) getKey(ctx context.Context, keyID string, current time.Time) (*jose.JSONWebKey, error) {
func (k *keySetCache) getKey(ctx context.Context, keyID string, current time.Time) (*jose.JSONWebKey, error) {
instanceID := authz.GetInstance(ctx).InstanceID()
v.mtx.RLock()
key, ok := v.instanceKeys[instanceID][keyID]
v.mtx.RUnlock()
k.mtx.RLock()
key, ok := k.instanceKeys[instanceID][keyID]
k.mtx.RUnlock()
if ok {
if key.Expiry().After(current) {
@@ -95,24 +101,24 @@ func (v *keySet) getKey(ctx context.Context, keyID string, current time.Time) (*
return nil, errors.ThrowInvalidArgument(nil, "OIDC-Zoh9E", "Errors.Key.ExpireBeforeNow")
}
key, err := v.queryKey(ctx, keyID, current)
key, err := k.queryKey(ctx, keyID, current)
if err != nil {
return nil, err
}
v.setKey(instanceID, keyID, key)
k.setKey(instanceID, keyID, key)
return jsonWebkey(key), nil
}
// VerifySignature implements the oidc.KeySet interface.
func (v *keySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
func (k *keySetCache) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
if len(jws.Signatures) != 1 {
return nil, errors.ThrowInvalidArgument(nil, "OIDC-Gid9s", "Errors.Token.Invalid")
}
key, err := v.getKey(ctx, jws.Signatures[0].Header.KeyID, time.Now())
key, err := k.getKey(ctx, jws.Signatures[0].Header.KeyID, time.Now())
if err != nil {
return nil, err
}
return jws.Verify(&key)
return jws.Verify(key)
}
func jsonWebkey(key query.PublicKey) *jose.JSONWebKey {
@@ -124,6 +130,35 @@ func jsonWebkey(key query.PublicKey) *jose.JSONWebKey {
}
}
// keySetMap is a mapping of key IDs to public key data.
type keySetMap map[string][]byte
// getKey finds the keyID and parses the public key data
// into a JSONWebKey.
func (k keySetMap) getKey(keyID string) (*jose.JSONWebKey, error) {
pubKey, err := crypto.BytesToPublicKey(k[keyID])
if err != nil {
return nil, err
}
return &jose.JSONWebKey{
Key: pubKey,
KeyID: keyID,
Use: "sig",
}, nil
}
// VerifySignature implements the oidc.KeySet interface.
func (k keySetMap) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
if len(jws.Signatures) != 1 {
return nil, errors.ThrowInvalidArgument(nil, "OIDC-Eeth6", "Errors.Token.Invalid")
}
key, err := k.getKey(jws.Signatures[0].Header.KeyID)
if err != nil {
return nil, err
}
return jws.Verify(key)
}
const (
locksTable = "projections.locks"
signingKey = "signing_key"