mirror of
				https://github.com/zitadel/zitadel.git
				synced 2025-10-25 08:30:37 +00:00 
			
		
		
		
	get key by id and cache them
This commit is contained in:
		| @@ -3,6 +3,7 @@ package oidc | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"sync" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/go-jose/go-jose/v3" | ||||
| @@ -19,6 +20,61 @@ import ( | ||||
| 	"github.com/zitadel/zitadel/internal/telemetry/tracing" | ||||
| ) | ||||
|  | ||||
| type keySet struct { | ||||
| 	mtx      sync.RWMutex | ||||
| 	keys     map[string]query.PublicKey | ||||
| 	queryKey func(ctx context.Context, keyID string, current time.Time) (query.PublicKey, error) | ||||
| } | ||||
|  | ||||
| func (v *keySet) getKey(ctx context.Context, keyID string, current time.Time) (*jose.JSONWebKey, error) { | ||||
| 	v.mtx.RLock() | ||||
| 	key, ok := v.keys[keyID] | ||||
| 	v.mtx.RUnlock() | ||||
|  | ||||
| 	if ok { | ||||
| 		if key.Expiry().After(current) { | ||||
| 			return jsonWebkey(key), nil | ||||
| 		} | ||||
| 		v.mtx.Lock() | ||||
| 		delete(v.keys, keyID) // cleanup expired keys | ||||
| 		v.mtx.Unlock() | ||||
|  | ||||
| 		return nil, errors.ThrowInvalidArgument(nil, "OIDC-Zoh9E", "Errors.Key.ExpireBeforeNow") | ||||
| 	} | ||||
|  | ||||
| 	key, err := v.queryKey(ctx, keyID, current) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	v.mtx.Lock() | ||||
| 	v.keys[key.ID()] = key | ||||
| 	v.mtx.Unlock() | ||||
|  | ||||
| 	return jsonWebkey(key), nil | ||||
| } | ||||
|  | ||||
| // VerifySignature implements the oidc.KeySet interface. | ||||
| func (v *keySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) { | ||||
| 	if len(jws.Signatures) != 1 { | ||||
| 		return nil, errors.ThrowInvalidArgument(nil, "OIDC-Gid9s", "Errors.Token.Invalid") | ||||
| 	} | ||||
| 	key, err := v.getKey(ctx, jws.Signatures[0].Header.KeyID, time.Now()) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return jws.Verify(&key) | ||||
| } | ||||
|  | ||||
| func jsonWebkey(key query.PublicKey) *jose.JSONWebKey { | ||||
| 	return &jose.JSONWebKey{ | ||||
| 		KeyID:     key.ID(), | ||||
| 		Algorithm: key.Algorithm(), | ||||
| 		Use:       key.Use().String(), | ||||
| 		Key:       key.Key(), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| const ( | ||||
| 	locksTable = "projections.locks" | ||||
| 	signingKey = "signing_key" | ||||
|   | ||||
| @@ -68,6 +68,7 @@ type OPStorage struct { | ||||
| 	command                           *command.Commands | ||||
| 	query                             *query.Queries | ||||
| 	eventstore                        *eventstore.Eventstore | ||||
| 	keySet                            *keySet | ||||
| 	defaultLoginURL                   string | ||||
| 	defaultLoginURLV2                 string | ||||
| 	defaultLogoutURLV2                string | ||||
| @@ -119,6 +120,7 @@ func NewServer( | ||||
| 	} | ||||
|  | ||||
| 	server := &Server{ | ||||
| 		storage:      storage, | ||||
| 		LegacyServer: op.NewLegacyServer(provider, endpoints(config.CustomEndpoints)), | ||||
| 	} | ||||
| 	metricTypes := []metrics.MetricType{metrics.MetricTypeRequestCount, metrics.MetricTypeStatusCode, metrics.MetricTypeTotalCount} | ||||
| @@ -172,12 +174,16 @@ func createOPConfig(config Config, defaultLogoutRedirectURI string, cryptoKey [] | ||||
| 	return opConfig, nil | ||||
| } | ||||
|  | ||||
| func newStorage(config Config, command *command.Commands, query *query.Queries, repo repository.Repository, encAlg crypto.EncryptionAlgorithm, es *eventstore.Eventstore, db *database.DB, externalSecure bool) *OPStorage { | ||||
| func newStorage(config Config, command *command.Commands, queries *query.Queries, repo repository.Repository, encAlg crypto.EncryptionAlgorithm, es *eventstore.Eventstore, db *database.DB, externalSecure bool) *OPStorage { | ||||
| 	return &OPStorage{ | ||||
| 		repo:                              repo, | ||||
| 		command:                           command, | ||||
| 		query:                             query, | ||||
| 		eventstore:                        es, | ||||
| 		repo:       repo, | ||||
| 		command:    command, | ||||
| 		query:      queries, | ||||
| 		eventstore: es, | ||||
| 		keySet: &keySet{ | ||||
| 			keys:     make(map[string]query.PublicKey), | ||||
| 			queryKey: queries.GetActivePublicKeyByID, | ||||
| 		}, | ||||
| 		defaultLoginURL:                   fmt.Sprintf("%s%s?%s=", login.HandlerPrefix, login.EndpointLogin, login.QueryAuthRequestID), | ||||
| 		defaultLoginURLV2:                 config.DefaultLoginURLV2, | ||||
| 		defaultLogoutURLV2:                config.DefaultLogoutURLV2, | ||||
|   | ||||
| @@ -2,7 +2,10 @@ package oidc | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"errors" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/zitadel/oidc/v3/pkg/oidc" | ||||
| 	"github.com/zitadel/oidc/v3/pkg/op" | ||||
| @@ -11,6 +14,7 @@ import ( | ||||
|  | ||||
| type Server struct { | ||||
| 	http.Handler | ||||
| 	storage *OPStorage | ||||
| 	*op.LegacyServer | ||||
| } | ||||
|  | ||||
| @@ -159,11 +163,60 @@ func (s *Server) DeviceToken(ctx context.Context, r *op.ClientRequest[oidc.Devic | ||||
| 	return s.LegacyServer.DeviceToken(ctx, r) | ||||
| } | ||||
|  | ||||
| func (s *Server) Introspect(ctx context.Context, r *op.Request[op.IntrospectionRequest]) (_ *op.Response, err error) { | ||||
| 	ctx, span := tracing.NewSpan(ctx) | ||||
| 	defer func() { span.EndWithError(err) }() | ||||
| func (s *Server) authenticateResourceClient(ctx context.Context, cc *op.ClientCredentials) (clientID string, err error) { | ||||
| 	if cc.ClientAssertion != "" { | ||||
| 		verifier := op.NewJWTProfileVerifier(s.storage, op.IssuerFromContext(ctx), 1*time.Hour, time.Second) | ||||
| 		profile, err := op.VerifyJWTAssertion(ctx, cc.ClientAssertion, verifier) | ||||
| 		if err != nil { | ||||
| 			return "", err | ||||
| 		} | ||||
| 		return profile.Issuer, nil | ||||
| 	} | ||||
|  | ||||
| 	return s.LegacyServer.Introspect(ctx, r) | ||||
| 	if err = s.storage.AuthorizeClientIDSecret(ctx, cc.ClientID, cc.ClientSecret); err != nil { | ||||
| 		if err != nil { | ||||
| 			return "", err | ||||
| 		} | ||||
| 	} | ||||
| 	return cc.ClientID, nil | ||||
| } | ||||
|  | ||||
| func (s *Server) getTokenIDAndSubject(ctx context.Context, accessToken string) (idToken, subject string, err error) { | ||||
| 	provider := s.Provider() | ||||
| 	tokenIDSubject, err := provider.Crypto().Decrypt(accessToken) | ||||
| 	if err == nil { | ||||
| 		splitToken := strings.Split(tokenIDSubject, ":") | ||||
| 		if len(splitToken) != 2 { | ||||
| 			return "", "", errors.New("invalid token format") | ||||
| 		} | ||||
| 		return splitToken[0], splitToken[1], nil | ||||
| 	} | ||||
|  | ||||
| 	verifier := op.NewAccessTokenVerifier(op.IssuerFromContext(ctx), s.storage.keySet) | ||||
| 	accessTokenClaims, err := op.VerifyAccessToken[*oidc.AccessTokenClaims](ctx, accessToken, verifier) | ||||
| 	if err != nil { | ||||
| 		return "", "", err | ||||
| 	} | ||||
| 	return accessTokenClaims.JWTID, accessTokenClaims.Subject, nil | ||||
| } | ||||
|  | ||||
| func (s *Server) Introspect(ctx context.Context, r *op.Request[op.IntrospectionRequest]) (_ *op.Response, err error) { | ||||
| 	clientID, err := s.authenticateResourceClient(ctx, r.Data.ClientCredentials) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	response := new(oidc.IntrospectionResponse) | ||||
| 	tokenID, subject, err := s.getTokenIDAndSubject(ctx, r.Data.Token) | ||||
| 	if err != nil { | ||||
| 		// TODO: log error | ||||
| 		return op.NewResponse(response), nil | ||||
| 	} | ||||
| 	err = s.storage.SetIntrospectionFromToken(ctx, response, tokenID, subject, clientID) | ||||
| 	if err != nil { | ||||
| 		return op.NewResponse(response), nil | ||||
| 	} | ||||
| 	response.Active = true | ||||
| 	return op.NewResponse(response), nil | ||||
| } | ||||
|  | ||||
| func (s *Server) UserInfo(ctx context.Context, r *op.Request[oidc.UserInfoRequest]) (_ *op.Response, err error) { | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Tim Möhlmann
					Tim Möhlmann