get key by id and cache them

This commit is contained in:
Tim Möhlmann 2023-11-01 15:59:23 +02:00
parent 814e09f1d5
commit 85e22c1521
6 changed files with 218 additions and 21 deletions

View File

@ -3,6 +3,7 @@ package oidc
import (
"context"
"fmt"
"sync"
"time"
"github.com/go-jose/go-jose/v3"
@ -19,6 +20,61 @@ import (
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
type keySet struct {
mtx sync.RWMutex
keys map[string]query.PublicKey
queryKey func(ctx context.Context, keyID string, current time.Time) (query.PublicKey, error)
}
func (v *keySet) getKey(ctx context.Context, keyID string, current time.Time) (*jose.JSONWebKey, error) {
v.mtx.RLock()
key, ok := v.keys[keyID]
v.mtx.RUnlock()
if ok {
if key.Expiry().After(current) {
return jsonWebkey(key), nil
}
v.mtx.Lock()
delete(v.keys, keyID) // cleanup expired keys
v.mtx.Unlock()
return nil, errors.ThrowInvalidArgument(nil, "OIDC-Zoh9E", "Errors.Key.ExpireBeforeNow")
}
key, err := v.queryKey(ctx, keyID, current)
if err != nil {
return nil, err
}
v.mtx.Lock()
v.keys[key.ID()] = key
v.mtx.Unlock()
return jsonWebkey(key), nil
}
// VerifySignature implements the oidc.KeySet interface.
func (v *keySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
if len(jws.Signatures) != 1 {
return nil, errors.ThrowInvalidArgument(nil, "OIDC-Gid9s", "Errors.Token.Invalid")
}
key, err := v.getKey(ctx, jws.Signatures[0].Header.KeyID, time.Now())
if err != nil {
return nil, err
}
return jws.Verify(&key)
}
func jsonWebkey(key query.PublicKey) *jose.JSONWebKey {
return &jose.JSONWebKey{
KeyID: key.ID(),
Algorithm: key.Algorithm(),
Use: key.Use().String(),
Key: key.Key(),
}
}
const (
locksTable = "projections.locks"
signingKey = "signing_key"

View File

@ -68,6 +68,7 @@ type OPStorage struct {
command *command.Commands
query *query.Queries
eventstore *eventstore.Eventstore
keySet *keySet
defaultLoginURL string
defaultLoginURLV2 string
defaultLogoutURLV2 string
@ -119,6 +120,7 @@ func NewServer(
}
server := &Server{
storage: storage,
LegacyServer: op.NewLegacyServer(provider, endpoints(config.CustomEndpoints)),
}
metricTypes := []metrics.MetricType{metrics.MetricTypeRequestCount, metrics.MetricTypeStatusCode, metrics.MetricTypeTotalCount}
@ -172,12 +174,16 @@ func createOPConfig(config Config, defaultLogoutRedirectURI string, cryptoKey []
return opConfig, nil
}
func newStorage(config Config, command *command.Commands, query *query.Queries, repo repository.Repository, encAlg crypto.EncryptionAlgorithm, es *eventstore.Eventstore, db *database.DB, externalSecure bool) *OPStorage {
func newStorage(config Config, command *command.Commands, queries *query.Queries, repo repository.Repository, encAlg crypto.EncryptionAlgorithm, es *eventstore.Eventstore, db *database.DB, externalSecure bool) *OPStorage {
return &OPStorage{
repo: repo,
command: command,
query: query,
eventstore: es,
repo: repo,
command: command,
query: queries,
eventstore: es,
keySet: &keySet{
keys: make(map[string]query.PublicKey),
queryKey: queries.GetActivePublicKeyByID,
},
defaultLoginURL: fmt.Sprintf("%s%s?%s=", login.HandlerPrefix, login.EndpointLogin, login.QueryAuthRequestID),
defaultLoginURLV2: config.DefaultLoginURLV2,
defaultLogoutURLV2: config.DefaultLogoutURLV2,

View File

@ -2,7 +2,10 @@ package oidc
import (
"context"
"errors"
"net/http"
"strings"
"time"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op"
@ -11,6 +14,7 @@ import (
type Server struct {
http.Handler
storage *OPStorage
*op.LegacyServer
}
@ -159,11 +163,60 @@ func (s *Server) DeviceToken(ctx context.Context, r *op.ClientRequest[oidc.Devic
return s.LegacyServer.DeviceToken(ctx, r)
}
func (s *Server) Introspect(ctx context.Context, r *op.Request[op.IntrospectionRequest]) (_ *op.Response, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
func (s *Server) authenticateResourceClient(ctx context.Context, cc *op.ClientCredentials) (clientID string, err error) {
if cc.ClientAssertion != "" {
verifier := op.NewJWTProfileVerifier(s.storage, op.IssuerFromContext(ctx), 1*time.Hour, time.Second)
profile, err := op.VerifyJWTAssertion(ctx, cc.ClientAssertion, verifier)
if err != nil {
return "", err
}
return profile.Issuer, nil
}
return s.LegacyServer.Introspect(ctx, r)
if err = s.storage.AuthorizeClientIDSecret(ctx, cc.ClientID, cc.ClientSecret); err != nil {
if err != nil {
return "", err
}
}
return cc.ClientID, nil
}
func (s *Server) getTokenIDAndSubject(ctx context.Context, accessToken string) (idToken, subject string, err error) {
provider := s.Provider()
tokenIDSubject, err := provider.Crypto().Decrypt(accessToken)
if err == nil {
splitToken := strings.Split(tokenIDSubject, ":")
if len(splitToken) != 2 {
return "", "", errors.New("invalid token format")
}
return splitToken[0], splitToken[1], nil
}
verifier := op.NewAccessTokenVerifier(op.IssuerFromContext(ctx), s.storage.keySet)
accessTokenClaims, err := op.VerifyAccessToken[*oidc.AccessTokenClaims](ctx, accessToken, verifier)
if err != nil {
return "", "", err
}
return accessTokenClaims.JWTID, accessTokenClaims.Subject, nil
}
func (s *Server) Introspect(ctx context.Context, r *op.Request[op.IntrospectionRequest]) (_ *op.Response, err error) {
clientID, err := s.authenticateResourceClient(ctx, r.Data.ClientCredentials)
if err != nil {
return nil, err
}
response := new(oidc.IntrospectionResponse)
tokenID, subject, err := s.getTokenIDAndSubject(ctx, r.Data.Token)
if err != nil {
// TODO: log error
return op.NewResponse(response), nil
}
err = s.storage.SetIntrospectionFromToken(ctx, response, tokenID, subject, clientID)
if err != nil {
return op.NewResponse(response), nil
}
response.Active = true
return op.NewResponse(response), nil
}
func (s *Server) UserInfo(ctx context.Context, r *op.Request[oidc.UserInfoRequest]) (_ *op.Response, err error) {

View File

@ -13,7 +13,9 @@ import (
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/query/projection"
"github.com/zitadel/zitadel/internal/repository/keypair"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
@ -349,3 +351,81 @@ func preparePrivateKeysQuery(ctx context.Context, db prepareDatabase) (sq.Select
}, nil
}
}
type PublicKeyReadModel struct {
eventstore.ReadModel
Algorithm string
Key *crypto.CryptoValue
Expiry time.Time
}
func NewPublicKeyReadModel(keyID, resourceOwner string) *PublicKeyReadModel {
return &PublicKeyReadModel{
ReadModel: eventstore.ReadModel{
AggregateID: keyID,
ResourceOwner: resourceOwner,
},
}
}
func (wm *PublicKeyReadModel) AppendEvents(events ...eventstore.Event) {
wm.ReadModel.AppendEvents(events...)
}
func (wm *PublicKeyReadModel) Reduce() error {
for _, event := range wm.Events {
switch e := event.(type) {
case *keypair.AddedEvent:
wm.Algorithm = e.Algorithm
wm.Key = e.PublicKey.Key
wm.Expiry = e.PublicKey.Expiry
}
}
return wm.ReadModel.Reduce()
}
func (wm *PublicKeyReadModel) Query() *eventstore.SearchQueryBuilder {
return eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
ResourceOwner(wm.ResourceOwner).
AddQuery().
AggregateTypes(keypair.AggregateType).
AggregateIDs(wm.AggregateID).
EventTypes(keypair.AddedEventType).
Builder()
}
func (q *Queries) GetActivePublicKeyByID(ctx context.Context, keyID string, after time.Time) (_ PublicKey, err error) {
model := NewPublicKeyReadModel(keyID, authz.GetInstance(ctx).InstanceID())
if err := q.eventstore.FilterToQueryReducer(ctx, model); err != nil {
return nil, err
}
if model.Algorithm == "" || model.Key == nil {
return nil, errors.ThrowNotFound(err, "QUERY-Ahf7x", "Errors.Key.NotFound")
}
if model.Expiry.After(after) {
return nil, errors.ThrowInvalidArgument(err, "QUERY-ciF4k", "Errors.Key.ExpireBeforeNow")
}
keyValue, err := crypto.Decrypt(model.Key, q.keyEncryptionAlgorithm)
if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-Ie4oh", "Errors.Internal")
}
publicKey, err := crypto.BytesToPublicKey(keyValue)
if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-Kai2Z", "Errors.Internal")
}
return &rsaPublicKey{
key: key{
id: model.AggregateID,
creationDate: model.CreationDate,
changeDate: model.ChangeDate,
sequence: model.ProcessedSequence,
resourceOwner: model.ResourceOwner,
algorithm: model.Algorithm,
// use: , TBD, what events update this and do we need it?
},
expiry: model.Expiry,
publicKey: publicKey,
}, nil
}

View File

@ -38,9 +38,10 @@ type Queries struct {
eventstore *eventstore.Eventstore
client *database.DB
idpConfigEncryption crypto.EncryptionAlgorithm
sessionTokenVerifier func(ctx context.Context, sessionToken string, sessionID string, tokenID string) (err error)
checkPermission domain.PermissionCheck
keyEncryptionAlgorithm crypto.EncryptionAlgorithm
idpConfigEncryption crypto.EncryptionAlgorithm
sessionTokenVerifier func(ctx context.Context, sessionToken string, sessionID string, tokenID string) (err error)
checkPermission domain.PermissionCheck
DefaultLanguage language.Tag
LoginDir http.FileSystem
@ -86,8 +87,16 @@ func StartQueries(
LoginTranslationFileContents: make(map[string][]byte),
NotificationTranslationFileContents: make(map[string][]byte),
zitadelRoles: zitadelRoles,
keyEncryptionAlgorithm: keyEncryptionAlgorithm,
idpConfigEncryption: idpConfigEncryption,
sessionTokenVerifier: sessionTokenVerifier,
defaultAuditLogRetention: defaultAuditLogRetention,
multifactors: domain.MultifactorConfigs{
OTP: domain.OTPConfig{
CryptoMFA: otpEncryption,
Issuer: defaults.Multifactors.OTP.Issuer,
},
},
defaultAuditLogRetention: defaultAuditLogRetention,
}
iam_repo.RegisterEventMappers(repo.eventstore)
usr_repo.RegisterEventMappers(repo.eventstore)
@ -103,14 +112,6 @@ func StartQueries(
quota.RegisterEventMappers(repo.eventstore)
limits.RegisterEventMappers(repo.eventstore)
repo.idpConfigEncryption = idpConfigEncryption
repo.multifactors = domain.MultifactorConfigs{
OTP: domain.OTPConfig{
CryptoMFA: otpEncryption,
Issuer: defaults.Multifactors.OTP.Issuer,
},
}
repo.checkPermission = permissionCheck(repo)
err = projection.Create(ctx, sqlClient, es, projections, keyEncryptionAlgorithm, certEncryptionAlgorithm, systemAPIUsers)

View File

@ -428,6 +428,7 @@ Errors:
UserSession:
NotFound: UserSession not found
Key:
NotFound: Key not found
ExpireBeforeNow: The expiration date is in the past
Login:
LoginPolicy: