mirror of
https://github.com/zitadel/zitadel.git
synced 2025-10-25 08:30:37 +00:00
get key by id and cache them
This commit is contained in:
@@ -3,6 +3,7 @@ package oidc
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
@@ -19,6 +20,61 @@ import (
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
)
|
||||
|
||||
type keySet struct {
|
||||
mtx sync.RWMutex
|
||||
keys map[string]query.PublicKey
|
||||
queryKey func(ctx context.Context, keyID string, current time.Time) (query.PublicKey, error)
|
||||
}
|
||||
|
||||
func (v *keySet) getKey(ctx context.Context, keyID string, current time.Time) (*jose.JSONWebKey, error) {
|
||||
v.mtx.RLock()
|
||||
key, ok := v.keys[keyID]
|
||||
v.mtx.RUnlock()
|
||||
|
||||
if ok {
|
||||
if key.Expiry().After(current) {
|
||||
return jsonWebkey(key), nil
|
||||
}
|
||||
v.mtx.Lock()
|
||||
delete(v.keys, keyID) // cleanup expired keys
|
||||
v.mtx.Unlock()
|
||||
|
||||
return nil, errors.ThrowInvalidArgument(nil, "OIDC-Zoh9E", "Errors.Key.ExpireBeforeNow")
|
||||
}
|
||||
|
||||
key, err := v.queryKey(ctx, keyID, current)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
v.mtx.Lock()
|
||||
v.keys[key.ID()] = key
|
||||
v.mtx.Unlock()
|
||||
|
||||
return jsonWebkey(key), nil
|
||||
}
|
||||
|
||||
// VerifySignature implements the oidc.KeySet interface.
|
||||
func (v *keySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
|
||||
if len(jws.Signatures) != 1 {
|
||||
return nil, errors.ThrowInvalidArgument(nil, "OIDC-Gid9s", "Errors.Token.Invalid")
|
||||
}
|
||||
key, err := v.getKey(ctx, jws.Signatures[0].Header.KeyID, time.Now())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return jws.Verify(&key)
|
||||
}
|
||||
|
||||
func jsonWebkey(key query.PublicKey) *jose.JSONWebKey {
|
||||
return &jose.JSONWebKey{
|
||||
KeyID: key.ID(),
|
||||
Algorithm: key.Algorithm(),
|
||||
Use: key.Use().String(),
|
||||
Key: key.Key(),
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
locksTable = "projections.locks"
|
||||
signingKey = "signing_key"
|
||||
|
||||
@@ -68,6 +68,7 @@ type OPStorage struct {
|
||||
command *command.Commands
|
||||
query *query.Queries
|
||||
eventstore *eventstore.Eventstore
|
||||
keySet *keySet
|
||||
defaultLoginURL string
|
||||
defaultLoginURLV2 string
|
||||
defaultLogoutURLV2 string
|
||||
@@ -119,6 +120,7 @@ func NewServer(
|
||||
}
|
||||
|
||||
server := &Server{
|
||||
storage: storage,
|
||||
LegacyServer: op.NewLegacyServer(provider, endpoints(config.CustomEndpoints)),
|
||||
}
|
||||
metricTypes := []metrics.MetricType{metrics.MetricTypeRequestCount, metrics.MetricTypeStatusCode, metrics.MetricTypeTotalCount}
|
||||
@@ -172,12 +174,16 @@ func createOPConfig(config Config, defaultLogoutRedirectURI string, cryptoKey []
|
||||
return opConfig, nil
|
||||
}
|
||||
|
||||
func newStorage(config Config, command *command.Commands, query *query.Queries, repo repository.Repository, encAlg crypto.EncryptionAlgorithm, es *eventstore.Eventstore, db *database.DB, externalSecure bool) *OPStorage {
|
||||
func newStorage(config Config, command *command.Commands, queries *query.Queries, repo repository.Repository, encAlg crypto.EncryptionAlgorithm, es *eventstore.Eventstore, db *database.DB, externalSecure bool) *OPStorage {
|
||||
return &OPStorage{
|
||||
repo: repo,
|
||||
command: command,
|
||||
query: query,
|
||||
eventstore: es,
|
||||
repo: repo,
|
||||
command: command,
|
||||
query: queries,
|
||||
eventstore: es,
|
||||
keySet: &keySet{
|
||||
keys: make(map[string]query.PublicKey),
|
||||
queryKey: queries.GetActivePublicKeyByID,
|
||||
},
|
||||
defaultLoginURL: fmt.Sprintf("%s%s?%s=", login.HandlerPrefix, login.EndpointLogin, login.QueryAuthRequestID),
|
||||
defaultLoginURLV2: config.DefaultLoginURLV2,
|
||||
defaultLogoutURLV2: config.DefaultLogoutURLV2,
|
||||
|
||||
@@ -2,7 +2,10 @@ package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
@@ -11,6 +14,7 @@ import (
|
||||
|
||||
type Server struct {
|
||||
http.Handler
|
||||
storage *OPStorage
|
||||
*op.LegacyServer
|
||||
}
|
||||
|
||||
@@ -159,11 +163,60 @@ func (s *Server) DeviceToken(ctx context.Context, r *op.ClientRequest[oidc.Devic
|
||||
return s.LegacyServer.DeviceToken(ctx, r)
|
||||
}
|
||||
|
||||
func (s *Server) Introspect(ctx context.Context, r *op.Request[op.IntrospectionRequest]) (_ *op.Response, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
func (s *Server) authenticateResourceClient(ctx context.Context, cc *op.ClientCredentials) (clientID string, err error) {
|
||||
if cc.ClientAssertion != "" {
|
||||
verifier := op.NewJWTProfileVerifier(s.storage, op.IssuerFromContext(ctx), 1*time.Hour, time.Second)
|
||||
profile, err := op.VerifyJWTAssertion(ctx, cc.ClientAssertion, verifier)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return profile.Issuer, nil
|
||||
}
|
||||
|
||||
return s.LegacyServer.Introspect(ctx, r)
|
||||
if err = s.storage.AuthorizeClientIDSecret(ctx, cc.ClientID, cc.ClientSecret); err != nil {
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
return cc.ClientID, nil
|
||||
}
|
||||
|
||||
func (s *Server) getTokenIDAndSubject(ctx context.Context, accessToken string) (idToken, subject string, err error) {
|
||||
provider := s.Provider()
|
||||
tokenIDSubject, err := provider.Crypto().Decrypt(accessToken)
|
||||
if err == nil {
|
||||
splitToken := strings.Split(tokenIDSubject, ":")
|
||||
if len(splitToken) != 2 {
|
||||
return "", "", errors.New("invalid token format")
|
||||
}
|
||||
return splitToken[0], splitToken[1], nil
|
||||
}
|
||||
|
||||
verifier := op.NewAccessTokenVerifier(op.IssuerFromContext(ctx), s.storage.keySet)
|
||||
accessTokenClaims, err := op.VerifyAccessToken[*oidc.AccessTokenClaims](ctx, accessToken, verifier)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
return accessTokenClaims.JWTID, accessTokenClaims.Subject, nil
|
||||
}
|
||||
|
||||
func (s *Server) Introspect(ctx context.Context, r *op.Request[op.IntrospectionRequest]) (_ *op.Response, err error) {
|
||||
clientID, err := s.authenticateResourceClient(ctx, r.Data.ClientCredentials)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
response := new(oidc.IntrospectionResponse)
|
||||
tokenID, subject, err := s.getTokenIDAndSubject(ctx, r.Data.Token)
|
||||
if err != nil {
|
||||
// TODO: log error
|
||||
return op.NewResponse(response), nil
|
||||
}
|
||||
err = s.storage.SetIntrospectionFromToken(ctx, response, tokenID, subject, clientID)
|
||||
if err != nil {
|
||||
return op.NewResponse(response), nil
|
||||
}
|
||||
response.Active = true
|
||||
return op.NewResponse(response), nil
|
||||
}
|
||||
|
||||
func (s *Server) UserInfo(ctx context.Context, r *op.Request[oidc.UserInfoRequest]) (_ *op.Response, err error) {
|
||||
|
||||
Reference in New Issue
Block a user