get key by id and cache them

This commit is contained in:
Tim Möhlmann 2023-11-01 15:59:23 +02:00
parent 814e09f1d5
commit 85e22c1521
6 changed files with 218 additions and 21 deletions

View File

@ -3,6 +3,7 @@ package oidc
import ( import (
"context" "context"
"fmt" "fmt"
"sync"
"time" "time"
"github.com/go-jose/go-jose/v3" "github.com/go-jose/go-jose/v3"
@ -19,6 +20,61 @@ import (
"github.com/zitadel/zitadel/internal/telemetry/tracing" "github.com/zitadel/zitadel/internal/telemetry/tracing"
) )
type keySet struct {
mtx sync.RWMutex
keys map[string]query.PublicKey
queryKey func(ctx context.Context, keyID string, current time.Time) (query.PublicKey, error)
}
func (v *keySet) getKey(ctx context.Context, keyID string, current time.Time) (*jose.JSONWebKey, error) {
v.mtx.RLock()
key, ok := v.keys[keyID]
v.mtx.RUnlock()
if ok {
if key.Expiry().After(current) {
return jsonWebkey(key), nil
}
v.mtx.Lock()
delete(v.keys, keyID) // cleanup expired keys
v.mtx.Unlock()
return nil, errors.ThrowInvalidArgument(nil, "OIDC-Zoh9E", "Errors.Key.ExpireBeforeNow")
}
key, err := v.queryKey(ctx, keyID, current)
if err != nil {
return nil, err
}
v.mtx.Lock()
v.keys[key.ID()] = key
v.mtx.Unlock()
return jsonWebkey(key), nil
}
// VerifySignature implements the oidc.KeySet interface.
func (v *keySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
if len(jws.Signatures) != 1 {
return nil, errors.ThrowInvalidArgument(nil, "OIDC-Gid9s", "Errors.Token.Invalid")
}
key, err := v.getKey(ctx, jws.Signatures[0].Header.KeyID, time.Now())
if err != nil {
return nil, err
}
return jws.Verify(&key)
}
func jsonWebkey(key query.PublicKey) *jose.JSONWebKey {
return &jose.JSONWebKey{
KeyID: key.ID(),
Algorithm: key.Algorithm(),
Use: key.Use().String(),
Key: key.Key(),
}
}
const ( const (
locksTable = "projections.locks" locksTable = "projections.locks"
signingKey = "signing_key" signingKey = "signing_key"

View File

@ -68,6 +68,7 @@ type OPStorage struct {
command *command.Commands command *command.Commands
query *query.Queries query *query.Queries
eventstore *eventstore.Eventstore eventstore *eventstore.Eventstore
keySet *keySet
defaultLoginURL string defaultLoginURL string
defaultLoginURLV2 string defaultLoginURLV2 string
defaultLogoutURLV2 string defaultLogoutURLV2 string
@ -119,6 +120,7 @@ func NewServer(
} }
server := &Server{ server := &Server{
storage: storage,
LegacyServer: op.NewLegacyServer(provider, endpoints(config.CustomEndpoints)), LegacyServer: op.NewLegacyServer(provider, endpoints(config.CustomEndpoints)),
} }
metricTypes := []metrics.MetricType{metrics.MetricTypeRequestCount, metrics.MetricTypeStatusCode, metrics.MetricTypeTotalCount} metricTypes := []metrics.MetricType{metrics.MetricTypeRequestCount, metrics.MetricTypeStatusCode, metrics.MetricTypeTotalCount}
@ -172,12 +174,16 @@ func createOPConfig(config Config, defaultLogoutRedirectURI string, cryptoKey []
return opConfig, nil return opConfig, nil
} }
func newStorage(config Config, command *command.Commands, query *query.Queries, repo repository.Repository, encAlg crypto.EncryptionAlgorithm, es *eventstore.Eventstore, db *database.DB, externalSecure bool) *OPStorage { func newStorage(config Config, command *command.Commands, queries *query.Queries, repo repository.Repository, encAlg crypto.EncryptionAlgorithm, es *eventstore.Eventstore, db *database.DB, externalSecure bool) *OPStorage {
return &OPStorage{ return &OPStorage{
repo: repo, repo: repo,
command: command, command: command,
query: query, query: queries,
eventstore: es, eventstore: es,
keySet: &keySet{
keys: make(map[string]query.PublicKey),
queryKey: queries.GetActivePublicKeyByID,
},
defaultLoginURL: fmt.Sprintf("%s%s?%s=", login.HandlerPrefix, login.EndpointLogin, login.QueryAuthRequestID), defaultLoginURL: fmt.Sprintf("%s%s?%s=", login.HandlerPrefix, login.EndpointLogin, login.QueryAuthRequestID),
defaultLoginURLV2: config.DefaultLoginURLV2, defaultLoginURLV2: config.DefaultLoginURLV2,
defaultLogoutURLV2: config.DefaultLogoutURLV2, defaultLogoutURLV2: config.DefaultLogoutURLV2,

View File

@ -2,7 +2,10 @@ package oidc
import ( import (
"context" "context"
"errors"
"net/http" "net/http"
"strings"
"time"
"github.com/zitadel/oidc/v3/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op" "github.com/zitadel/oidc/v3/pkg/op"
@ -11,6 +14,7 @@ import (
type Server struct { type Server struct {
http.Handler http.Handler
storage *OPStorage
*op.LegacyServer *op.LegacyServer
} }
@ -159,11 +163,60 @@ func (s *Server) DeviceToken(ctx context.Context, r *op.ClientRequest[oidc.Devic
return s.LegacyServer.DeviceToken(ctx, r) return s.LegacyServer.DeviceToken(ctx, r)
} }
func (s *Server) Introspect(ctx context.Context, r *op.Request[op.IntrospectionRequest]) (_ *op.Response, err error) { func (s *Server) authenticateResourceClient(ctx context.Context, cc *op.ClientCredentials) (clientID string, err error) {
ctx, span := tracing.NewSpan(ctx) if cc.ClientAssertion != "" {
defer func() { span.EndWithError(err) }() verifier := op.NewJWTProfileVerifier(s.storage, op.IssuerFromContext(ctx), 1*time.Hour, time.Second)
profile, err := op.VerifyJWTAssertion(ctx, cc.ClientAssertion, verifier)
if err != nil {
return "", err
}
return profile.Issuer, nil
}
return s.LegacyServer.Introspect(ctx, r) if err = s.storage.AuthorizeClientIDSecret(ctx, cc.ClientID, cc.ClientSecret); err != nil {
if err != nil {
return "", err
}
}
return cc.ClientID, nil
}
func (s *Server) getTokenIDAndSubject(ctx context.Context, accessToken string) (idToken, subject string, err error) {
provider := s.Provider()
tokenIDSubject, err := provider.Crypto().Decrypt(accessToken)
if err == nil {
splitToken := strings.Split(tokenIDSubject, ":")
if len(splitToken) != 2 {
return "", "", errors.New("invalid token format")
}
return splitToken[0], splitToken[1], nil
}
verifier := op.NewAccessTokenVerifier(op.IssuerFromContext(ctx), s.storage.keySet)
accessTokenClaims, err := op.VerifyAccessToken[*oidc.AccessTokenClaims](ctx, accessToken, verifier)
if err != nil {
return "", "", err
}
return accessTokenClaims.JWTID, accessTokenClaims.Subject, nil
}
func (s *Server) Introspect(ctx context.Context, r *op.Request[op.IntrospectionRequest]) (_ *op.Response, err error) {
clientID, err := s.authenticateResourceClient(ctx, r.Data.ClientCredentials)
if err != nil {
return nil, err
}
response := new(oidc.IntrospectionResponse)
tokenID, subject, err := s.getTokenIDAndSubject(ctx, r.Data.Token)
if err != nil {
// TODO: log error
return op.NewResponse(response), nil
}
err = s.storage.SetIntrospectionFromToken(ctx, response, tokenID, subject, clientID)
if err != nil {
return op.NewResponse(response), nil
}
response.Active = true
return op.NewResponse(response), nil
} }
func (s *Server) UserInfo(ctx context.Context, r *op.Request[oidc.UserInfoRequest]) (_ *op.Response, err error) { func (s *Server) UserInfo(ctx context.Context, r *op.Request[oidc.UserInfoRequest]) (_ *op.Response, err error) {

View File

@ -13,7 +13,9 @@ import (
"github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors" "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/query/projection" "github.com/zitadel/zitadel/internal/query/projection"
"github.com/zitadel/zitadel/internal/repository/keypair"
"github.com/zitadel/zitadel/internal/telemetry/tracing" "github.com/zitadel/zitadel/internal/telemetry/tracing"
) )
@ -349,3 +351,81 @@ func preparePrivateKeysQuery(ctx context.Context, db prepareDatabase) (sq.Select
}, nil }, nil
} }
} }
type PublicKeyReadModel struct {
eventstore.ReadModel
Algorithm string
Key *crypto.CryptoValue
Expiry time.Time
}
func NewPublicKeyReadModel(keyID, resourceOwner string) *PublicKeyReadModel {
return &PublicKeyReadModel{
ReadModel: eventstore.ReadModel{
AggregateID: keyID,
ResourceOwner: resourceOwner,
},
}
}
func (wm *PublicKeyReadModel) AppendEvents(events ...eventstore.Event) {
wm.ReadModel.AppendEvents(events...)
}
func (wm *PublicKeyReadModel) Reduce() error {
for _, event := range wm.Events {
switch e := event.(type) {
case *keypair.AddedEvent:
wm.Algorithm = e.Algorithm
wm.Key = e.PublicKey.Key
wm.Expiry = e.PublicKey.Expiry
}
}
return wm.ReadModel.Reduce()
}
func (wm *PublicKeyReadModel) Query() *eventstore.SearchQueryBuilder {
return eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
ResourceOwner(wm.ResourceOwner).
AddQuery().
AggregateTypes(keypair.AggregateType).
AggregateIDs(wm.AggregateID).
EventTypes(keypair.AddedEventType).
Builder()
}
func (q *Queries) GetActivePublicKeyByID(ctx context.Context, keyID string, after time.Time) (_ PublicKey, err error) {
model := NewPublicKeyReadModel(keyID, authz.GetInstance(ctx).InstanceID())
if err := q.eventstore.FilterToQueryReducer(ctx, model); err != nil {
return nil, err
}
if model.Algorithm == "" || model.Key == nil {
return nil, errors.ThrowNotFound(err, "QUERY-Ahf7x", "Errors.Key.NotFound")
}
if model.Expiry.After(after) {
return nil, errors.ThrowInvalidArgument(err, "QUERY-ciF4k", "Errors.Key.ExpireBeforeNow")
}
keyValue, err := crypto.Decrypt(model.Key, q.keyEncryptionAlgorithm)
if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-Ie4oh", "Errors.Internal")
}
publicKey, err := crypto.BytesToPublicKey(keyValue)
if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-Kai2Z", "Errors.Internal")
}
return &rsaPublicKey{
key: key{
id: model.AggregateID,
creationDate: model.CreationDate,
changeDate: model.ChangeDate,
sequence: model.ProcessedSequence,
resourceOwner: model.ResourceOwner,
algorithm: model.Algorithm,
// use: , TBD, what events update this and do we need it?
},
expiry: model.Expiry,
publicKey: publicKey,
}, nil
}

View File

@ -38,9 +38,10 @@ type Queries struct {
eventstore *eventstore.Eventstore eventstore *eventstore.Eventstore
client *database.DB client *database.DB
idpConfigEncryption crypto.EncryptionAlgorithm keyEncryptionAlgorithm crypto.EncryptionAlgorithm
sessionTokenVerifier func(ctx context.Context, sessionToken string, sessionID string, tokenID string) (err error) idpConfigEncryption crypto.EncryptionAlgorithm
checkPermission domain.PermissionCheck sessionTokenVerifier func(ctx context.Context, sessionToken string, sessionID string, tokenID string) (err error)
checkPermission domain.PermissionCheck
DefaultLanguage language.Tag DefaultLanguage language.Tag
LoginDir http.FileSystem LoginDir http.FileSystem
@ -86,8 +87,16 @@ func StartQueries(
LoginTranslationFileContents: make(map[string][]byte), LoginTranslationFileContents: make(map[string][]byte),
NotificationTranslationFileContents: make(map[string][]byte), NotificationTranslationFileContents: make(map[string][]byte),
zitadelRoles: zitadelRoles, zitadelRoles: zitadelRoles,
keyEncryptionAlgorithm: keyEncryptionAlgorithm,
idpConfigEncryption: idpConfigEncryption,
sessionTokenVerifier: sessionTokenVerifier, sessionTokenVerifier: sessionTokenVerifier,
defaultAuditLogRetention: defaultAuditLogRetention, multifactors: domain.MultifactorConfigs{
OTP: domain.OTPConfig{
CryptoMFA: otpEncryption,
Issuer: defaults.Multifactors.OTP.Issuer,
},
},
defaultAuditLogRetention: defaultAuditLogRetention,
} }
iam_repo.RegisterEventMappers(repo.eventstore) iam_repo.RegisterEventMappers(repo.eventstore)
usr_repo.RegisterEventMappers(repo.eventstore) usr_repo.RegisterEventMappers(repo.eventstore)
@ -103,14 +112,6 @@ func StartQueries(
quota.RegisterEventMappers(repo.eventstore) quota.RegisterEventMappers(repo.eventstore)
limits.RegisterEventMappers(repo.eventstore) limits.RegisterEventMappers(repo.eventstore)
repo.idpConfigEncryption = idpConfigEncryption
repo.multifactors = domain.MultifactorConfigs{
OTP: domain.OTPConfig{
CryptoMFA: otpEncryption,
Issuer: defaults.Multifactors.OTP.Issuer,
},
}
repo.checkPermission = permissionCheck(repo) repo.checkPermission = permissionCheck(repo)
err = projection.Create(ctx, sqlClient, es, projections, keyEncryptionAlgorithm, certEncryptionAlgorithm, systemAPIUsers) err = projection.Create(ctx, sqlClient, es, projections, keyEncryptionAlgorithm, certEncryptionAlgorithm, systemAPIUsers)

View File

@ -428,6 +428,7 @@ Errors:
UserSession: UserSession:
NotFound: UserSession not found NotFound: UserSession not found
Key: Key:
NotFound: Key not found
ExpireBeforeNow: The expiration date is in the past ExpireBeforeNow: The expiration date is in the past
Login: Login:
LoginPolicy: LoginPolicy: