get key by id and cache them

This commit is contained in:
Tim Möhlmann 2023-11-01 15:59:23 +02:00
parent 814e09f1d5
commit 85e22c1521
6 changed files with 218 additions and 21 deletions

View File

@ -3,6 +3,7 @@ package oidc
import (
"context"
"fmt"
"sync"
"time"
"github.com/go-jose/go-jose/v3"
@ -19,6 +20,61 @@ import (
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
type keySet struct {
mtx sync.RWMutex
keys map[string]query.PublicKey
queryKey func(ctx context.Context, keyID string, current time.Time) (query.PublicKey, error)
}
func (v *keySet) getKey(ctx context.Context, keyID string, current time.Time) (*jose.JSONWebKey, error) {
v.mtx.RLock()
key, ok := v.keys[keyID]
v.mtx.RUnlock()
if ok {
if key.Expiry().After(current) {
return jsonWebkey(key), nil
}
v.mtx.Lock()
delete(v.keys, keyID) // cleanup expired keys
v.mtx.Unlock()
return nil, errors.ThrowInvalidArgument(nil, "OIDC-Zoh9E", "Errors.Key.ExpireBeforeNow")
}
key, err := v.queryKey(ctx, keyID, current)
if err != nil {
return nil, err
}
v.mtx.Lock()
v.keys[key.ID()] = key
v.mtx.Unlock()
return jsonWebkey(key), nil
}
// VerifySignature implements the oidc.KeySet interface.
func (v *keySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
if len(jws.Signatures) != 1 {
return nil, errors.ThrowInvalidArgument(nil, "OIDC-Gid9s", "Errors.Token.Invalid")
}
key, err := v.getKey(ctx, jws.Signatures[0].Header.KeyID, time.Now())
if err != nil {
return nil, err
}
return jws.Verify(&key)
}
func jsonWebkey(key query.PublicKey) *jose.JSONWebKey {
return &jose.JSONWebKey{
KeyID: key.ID(),
Algorithm: key.Algorithm(),
Use: key.Use().String(),
Key: key.Key(),
}
}
const (
locksTable = "projections.locks"
signingKey = "signing_key"

View File

@ -68,6 +68,7 @@ type OPStorage struct {
command *command.Commands
query *query.Queries
eventstore *eventstore.Eventstore
keySet *keySet
defaultLoginURL string
defaultLoginURLV2 string
defaultLogoutURLV2 string
@ -119,6 +120,7 @@ func NewServer(
}
server := &Server{
storage: storage,
LegacyServer: op.NewLegacyServer(provider, endpoints(config.CustomEndpoints)),
}
metricTypes := []metrics.MetricType{metrics.MetricTypeRequestCount, metrics.MetricTypeStatusCode, metrics.MetricTypeTotalCount}
@ -172,12 +174,16 @@ func createOPConfig(config Config, defaultLogoutRedirectURI string, cryptoKey []
return opConfig, nil
}
func newStorage(config Config, command *command.Commands, query *query.Queries, repo repository.Repository, encAlg crypto.EncryptionAlgorithm, es *eventstore.Eventstore, db *database.DB, externalSecure bool) *OPStorage {
func newStorage(config Config, command *command.Commands, queries *query.Queries, repo repository.Repository, encAlg crypto.EncryptionAlgorithm, es *eventstore.Eventstore, db *database.DB, externalSecure bool) *OPStorage {
return &OPStorage{
repo: repo,
command: command,
query: query,
query: queries,
eventstore: es,
keySet: &keySet{
keys: make(map[string]query.PublicKey),
queryKey: queries.GetActivePublicKeyByID,
},
defaultLoginURL: fmt.Sprintf("%s%s?%s=", login.HandlerPrefix, login.EndpointLogin, login.QueryAuthRequestID),
defaultLoginURLV2: config.DefaultLoginURLV2,
defaultLogoutURLV2: config.DefaultLogoutURLV2,

View File

@ -2,7 +2,10 @@ package oidc
import (
"context"
"errors"
"net/http"
"strings"
"time"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op"
@ -11,6 +14,7 @@ import (
type Server struct {
http.Handler
storage *OPStorage
*op.LegacyServer
}
@ -159,11 +163,60 @@ func (s *Server) DeviceToken(ctx context.Context, r *op.ClientRequest[oidc.Devic
return s.LegacyServer.DeviceToken(ctx, r)
}
func (s *Server) Introspect(ctx context.Context, r *op.Request[op.IntrospectionRequest]) (_ *op.Response, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
func (s *Server) authenticateResourceClient(ctx context.Context, cc *op.ClientCredentials) (clientID string, err error) {
if cc.ClientAssertion != "" {
verifier := op.NewJWTProfileVerifier(s.storage, op.IssuerFromContext(ctx), 1*time.Hour, time.Second)
profile, err := op.VerifyJWTAssertion(ctx, cc.ClientAssertion, verifier)
if err != nil {
return "", err
}
return profile.Issuer, nil
}
return s.LegacyServer.Introspect(ctx, r)
if err = s.storage.AuthorizeClientIDSecret(ctx, cc.ClientID, cc.ClientSecret); err != nil {
if err != nil {
return "", err
}
}
return cc.ClientID, nil
}
func (s *Server) getTokenIDAndSubject(ctx context.Context, accessToken string) (idToken, subject string, err error) {
provider := s.Provider()
tokenIDSubject, err := provider.Crypto().Decrypt(accessToken)
if err == nil {
splitToken := strings.Split(tokenIDSubject, ":")
if len(splitToken) != 2 {
return "", "", errors.New("invalid token format")
}
return splitToken[0], splitToken[1], nil
}
verifier := op.NewAccessTokenVerifier(op.IssuerFromContext(ctx), s.storage.keySet)
accessTokenClaims, err := op.VerifyAccessToken[*oidc.AccessTokenClaims](ctx, accessToken, verifier)
if err != nil {
return "", "", err
}
return accessTokenClaims.JWTID, accessTokenClaims.Subject, nil
}
func (s *Server) Introspect(ctx context.Context, r *op.Request[op.IntrospectionRequest]) (_ *op.Response, err error) {
clientID, err := s.authenticateResourceClient(ctx, r.Data.ClientCredentials)
if err != nil {
return nil, err
}
response := new(oidc.IntrospectionResponse)
tokenID, subject, err := s.getTokenIDAndSubject(ctx, r.Data.Token)
if err != nil {
// TODO: log error
return op.NewResponse(response), nil
}
err = s.storage.SetIntrospectionFromToken(ctx, response, tokenID, subject, clientID)
if err != nil {
return op.NewResponse(response), nil
}
response.Active = true
return op.NewResponse(response), nil
}
func (s *Server) UserInfo(ctx context.Context, r *op.Request[oidc.UserInfoRequest]) (_ *op.Response, err error) {

View File

@ -13,7 +13,9 @@ import (
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/query/projection"
"github.com/zitadel/zitadel/internal/repository/keypair"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
@ -349,3 +351,81 @@ func preparePrivateKeysQuery(ctx context.Context, db prepareDatabase) (sq.Select
}, nil
}
}
type PublicKeyReadModel struct {
eventstore.ReadModel
Algorithm string
Key *crypto.CryptoValue
Expiry time.Time
}
func NewPublicKeyReadModel(keyID, resourceOwner string) *PublicKeyReadModel {
return &PublicKeyReadModel{
ReadModel: eventstore.ReadModel{
AggregateID: keyID,
ResourceOwner: resourceOwner,
},
}
}
func (wm *PublicKeyReadModel) AppendEvents(events ...eventstore.Event) {
wm.ReadModel.AppendEvents(events...)
}
func (wm *PublicKeyReadModel) Reduce() error {
for _, event := range wm.Events {
switch e := event.(type) {
case *keypair.AddedEvent:
wm.Algorithm = e.Algorithm
wm.Key = e.PublicKey.Key
wm.Expiry = e.PublicKey.Expiry
}
}
return wm.ReadModel.Reduce()
}
func (wm *PublicKeyReadModel) Query() *eventstore.SearchQueryBuilder {
return eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
ResourceOwner(wm.ResourceOwner).
AddQuery().
AggregateTypes(keypair.AggregateType).
AggregateIDs(wm.AggregateID).
EventTypes(keypair.AddedEventType).
Builder()
}
func (q *Queries) GetActivePublicKeyByID(ctx context.Context, keyID string, after time.Time) (_ PublicKey, err error) {
model := NewPublicKeyReadModel(keyID, authz.GetInstance(ctx).InstanceID())
if err := q.eventstore.FilterToQueryReducer(ctx, model); err != nil {
return nil, err
}
if model.Algorithm == "" || model.Key == nil {
return nil, errors.ThrowNotFound(err, "QUERY-Ahf7x", "Errors.Key.NotFound")
}
if model.Expiry.After(after) {
return nil, errors.ThrowInvalidArgument(err, "QUERY-ciF4k", "Errors.Key.ExpireBeforeNow")
}
keyValue, err := crypto.Decrypt(model.Key, q.keyEncryptionAlgorithm)
if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-Ie4oh", "Errors.Internal")
}
publicKey, err := crypto.BytesToPublicKey(keyValue)
if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-Kai2Z", "Errors.Internal")
}
return &rsaPublicKey{
key: key{
id: model.AggregateID,
creationDate: model.CreationDate,
changeDate: model.ChangeDate,
sequence: model.ProcessedSequence,
resourceOwner: model.ResourceOwner,
algorithm: model.Algorithm,
// use: , TBD, what events update this and do we need it?
},
expiry: model.Expiry,
publicKey: publicKey,
}, nil
}

View File

@ -38,6 +38,7 @@ type Queries struct {
eventstore *eventstore.Eventstore
client *database.DB
keyEncryptionAlgorithm crypto.EncryptionAlgorithm
idpConfigEncryption crypto.EncryptionAlgorithm
sessionTokenVerifier func(ctx context.Context, sessionToken string, sessionID string, tokenID string) (err error)
checkPermission domain.PermissionCheck
@ -86,7 +87,15 @@ func StartQueries(
LoginTranslationFileContents: make(map[string][]byte),
NotificationTranslationFileContents: make(map[string][]byte),
zitadelRoles: zitadelRoles,
keyEncryptionAlgorithm: keyEncryptionAlgorithm,
idpConfigEncryption: idpConfigEncryption,
sessionTokenVerifier: sessionTokenVerifier,
multifactors: domain.MultifactorConfigs{
OTP: domain.OTPConfig{
CryptoMFA: otpEncryption,
Issuer: defaults.Multifactors.OTP.Issuer,
},
},
defaultAuditLogRetention: defaultAuditLogRetention,
}
iam_repo.RegisterEventMappers(repo.eventstore)
@ -103,14 +112,6 @@ func StartQueries(
quota.RegisterEventMappers(repo.eventstore)
limits.RegisterEventMappers(repo.eventstore)
repo.idpConfigEncryption = idpConfigEncryption
repo.multifactors = domain.MultifactorConfigs{
OTP: domain.OTPConfig{
CryptoMFA: otpEncryption,
Issuer: defaults.Multifactors.OTP.Issuer,
},
}
repo.checkPermission = permissionCheck(repo)
err = projection.Create(ctx, sqlClient, es, projections, keyEncryptionAlgorithm, certEncryptionAlgorithm, systemAPIUsers)

View File

@ -428,6 +428,7 @@ Errors:
UserSession:
NotFound: UserSession not found
Key:
NotFound: Key not found
ExpireBeforeNow: The expiration date is in the past
Login:
LoginPolicy: