get key by id and cache them

This commit is contained in:
Tim Möhlmann
2023-11-01 15:59:23 +02:00
parent 814e09f1d5
commit 85e22c1521
6 changed files with 218 additions and 21 deletions

View File

@@ -3,6 +3,7 @@ package oidc
import (
"context"
"fmt"
"sync"
"time"
"github.com/go-jose/go-jose/v3"
@@ -19,6 +20,61 @@ import (
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
type keySet struct {
mtx sync.RWMutex
keys map[string]query.PublicKey
queryKey func(ctx context.Context, keyID string, current time.Time) (query.PublicKey, error)
}
func (v *keySet) getKey(ctx context.Context, keyID string, current time.Time) (*jose.JSONWebKey, error) {
v.mtx.RLock()
key, ok := v.keys[keyID]
v.mtx.RUnlock()
if ok {
if key.Expiry().After(current) {
return jsonWebkey(key), nil
}
v.mtx.Lock()
delete(v.keys, keyID) // cleanup expired keys
v.mtx.Unlock()
return nil, errors.ThrowInvalidArgument(nil, "OIDC-Zoh9E", "Errors.Key.ExpireBeforeNow")
}
key, err := v.queryKey(ctx, keyID, current)
if err != nil {
return nil, err
}
v.mtx.Lock()
v.keys[key.ID()] = key
v.mtx.Unlock()
return jsonWebkey(key), nil
}
// VerifySignature implements the oidc.KeySet interface.
func (v *keySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
if len(jws.Signatures) != 1 {
return nil, errors.ThrowInvalidArgument(nil, "OIDC-Gid9s", "Errors.Token.Invalid")
}
key, err := v.getKey(ctx, jws.Signatures[0].Header.KeyID, time.Now())
if err != nil {
return nil, err
}
return jws.Verify(&key)
}
func jsonWebkey(key query.PublicKey) *jose.JSONWebKey {
return &jose.JSONWebKey{
KeyID: key.ID(),
Algorithm: key.Algorithm(),
Use: key.Use().String(),
Key: key.Key(),
}
}
const (
locksTable = "projections.locks"
signingKey = "signing_key"