mirror of
https://github.com/zitadel/zitadel.git
synced 2025-10-20 09:32:34 +00:00
client and project in single query
This commit is contained in:
@@ -2,6 +2,7 @@ package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"slices"
|
||||
"strings"
|
||||
@@ -10,7 +11,10 @@ import (
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
"github.com/zitadel/zitadel/internal/command"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
errz "github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
)
|
||||
|
||||
func (s *Server) Introspect(ctx context.Context, r *op.Request[op.IntrospectionRequest]) (_ *op.Response, err error) {
|
||||
@@ -94,43 +98,56 @@ type instrospectionClientResult struct {
|
||||
}
|
||||
|
||||
func (s *Server) instrospectionClientAuth(ctx context.Context, cc *op.ClientCredentials, rc chan<- *instrospectionClientResult) {
|
||||
clientID := cc.ClientID
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
|
||||
if cc.ClientAssertion != "" {
|
||||
verifier := op.NewJWTProfileVerifier(s.storage, op.IssuerFromContext(ctx), 1*time.Hour, time.Second)
|
||||
profile, err := op.VerifyJWTAssertion(ctx, cc.ClientAssertion, verifier)
|
||||
clientID, projectID, err := func() (string, string, error) {
|
||||
client, err := s.clientFromCredentials(ctx, cc)
|
||||
if err != nil {
|
||||
rc <- &instrospectionClientResult{
|
||||
err: oidc.ErrUnauthorizedClient().WithParent(err),
|
||||
}
|
||||
return
|
||||
return "", "", err
|
||||
}
|
||||
clientID = profile.Issuer
|
||||
} else {
|
||||
if err := s.storage.AuthorizeClientIDSecret(ctx, cc.ClientID, cc.ClientSecret); err != nil {
|
||||
if err != nil {
|
||||
rc <- &instrospectionClientResult{
|
||||
err: oidc.ErrUnauthorizedClient().WithParent(err),
|
||||
}
|
||||
return
|
||||
|
||||
if cc.ClientAssertion != "" {
|
||||
verifier := op.NewJWTProfileVerifierKeySet(keySetMap(client.PublicKeys), op.IssuerFromContext(ctx), time.Hour, time.Second)
|
||||
if _, err := op.VerifyJWTAssertion(ctx, cc.ClientAssertion, verifier); err != nil {
|
||||
return "", "", oidc.ErrUnauthorizedClient().WithParent(err)
|
||||
}
|
||||
} else {
|
||||
if err := crypto.CompareHash(client.ClientSecret, []byte(cc.ClientSecret), s.hashAlg); err != nil {
|
||||
return "", "", oidc.ErrUnauthorizedClient().WithParent(err)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
return client.ClientID, client.ProjectID, nil
|
||||
}()
|
||||
|
||||
// TODO: give clients their own aggregate, so we can skip this query
|
||||
projectID, err := s.storage.query.ProjectIDFromClientID(ctx, clientID, false)
|
||||
if err != nil {
|
||||
rc <- &instrospectionClientResult{err: err}
|
||||
return
|
||||
}
|
||||
span.EndWithError(err)
|
||||
|
||||
rc <- &instrospectionClientResult{
|
||||
clientID: clientID,
|
||||
projectID: projectID,
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// clientFromCredentials parses the client ID early,
|
||||
// and makes a single query for the client for either auth methods.
|
||||
func (s *Server) clientFromCredentials(ctx context.Context, cc *op.ClientCredentials) (client *query.IntrospectionClient, err error) {
|
||||
if cc.ClientAssertion != "" {
|
||||
claims := new(oidc.JWTTokenRequest)
|
||||
if _, err := oidc.ParseToken(cc.ClientAssertion, claims); err != nil {
|
||||
return nil, oidc.ErrUnauthorizedClient().WithParent(err)
|
||||
}
|
||||
client, err = s.storage.query.GetIntrospectionClientByID(ctx, claims.Issuer, true)
|
||||
} else {
|
||||
client, err = s.storage.query.GetIntrospectionClientByID(ctx, cc.ClientID, false)
|
||||
}
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, oidc.ErrUnauthorizedClient().WithParent(err)
|
||||
}
|
||||
// any other error is regarded internal and should not be reported back to the client.
|
||||
return client, err
|
||||
}
|
||||
|
||||
type introspectionTokenResult struct {
|
||||
tokenID string
|
||||
userID string
|
||||
|
@@ -20,14 +20,21 @@ import (
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
)
|
||||
|
||||
type keySet struct {
|
||||
// keySetCache implements oidc.KeySet for Access Token verification.
|
||||
// Public Keys are cached in a 2-dimentional map of Instance ID and Key ID.
|
||||
// When a key is not present the queryKey function is called to obtain the key
|
||||
// from the database.
|
||||
type keySetCache struct {
|
||||
mtx sync.RWMutex
|
||||
instanceKeys map[string]map[string]query.PublicKey
|
||||
queryKey func(ctx context.Context, keyID string, current time.Time) (query.PublicKey, error)
|
||||
}
|
||||
|
||||
func newKeySet(background context.Context, purgeInterval time.Duration, queryKey func(ctx context.Context, keyID string, current time.Time) (query.PublicKey, error)) *keySet {
|
||||
k := &keySet{
|
||||
// newKeySet initializes a keySetCache and starts a purging Go routine,
|
||||
// which runs once every purgeInterval.
|
||||
// When the passed context is done, the purge routine will terminate.
|
||||
func newKeySet(background context.Context, purgeInterval time.Duration, queryKey func(ctx context.Context, keyID string, current time.Time) (query.PublicKey, error)) *keySetCache {
|
||||
k := &keySetCache{
|
||||
instanceKeys: make(map[string]map[string]query.PublicKey),
|
||||
queryKey: queryKey,
|
||||
}
|
||||
@@ -35,7 +42,7 @@ func newKeySet(background context.Context, purgeInterval time.Duration, queryKey
|
||||
return k
|
||||
}
|
||||
|
||||
func (v *keySet) purgeOnInterval(background context.Context, purgeInterval time.Duration) {
|
||||
func (k *keySetCache) purgeOnInterval(background context.Context, purgeInterval time.Duration) {
|
||||
timer := time.NewTimer(purgeInterval)
|
||||
defer func() {
|
||||
if !timer.Stop() {
|
||||
@@ -43,50 +50,49 @@ func (v *keySet) purgeOnInterval(background context.Context, purgeInterval time.
|
||||
}
|
||||
}()
|
||||
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case <-background.Done():
|
||||
break loop
|
||||
return
|
||||
case <-timer.C:
|
||||
timer.Reset(purgeInterval)
|
||||
}
|
||||
|
||||
// do the actual purging
|
||||
v.mtx.Lock()
|
||||
for instanceID, keys := range v.instanceKeys {
|
||||
k.mtx.Lock()
|
||||
for instanceID, keys := range k.instanceKeys {
|
||||
for keyID, key := range keys {
|
||||
if key.Expiry().Before(time.Now()) {
|
||||
delete(keys, keyID)
|
||||
}
|
||||
}
|
||||
if len(keys) == 0 {
|
||||
delete(v.instanceKeys, instanceID)
|
||||
delete(k.instanceKeys, instanceID)
|
||||
}
|
||||
}
|
||||
v.mtx.Unlock()
|
||||
k.mtx.Unlock()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (v *keySet) setKey(instanceID, keyID string, key query.PublicKey) {
|
||||
v.mtx.Lock()
|
||||
defer v.mtx.Unlock()
|
||||
func (k *keySetCache) setKey(instanceID, keyID string, key query.PublicKey) {
|
||||
k.mtx.Lock()
|
||||
defer k.mtx.Unlock()
|
||||
|
||||
if keys, ok := v.instanceKeys[instanceID]; ok {
|
||||
if keys, ok := k.instanceKeys[instanceID]; ok {
|
||||
keys[keyID] = key
|
||||
return
|
||||
}
|
||||
|
||||
v.instanceKeys[instanceID] = map[string]query.PublicKey{keyID: key}
|
||||
k.instanceKeys[instanceID] = map[string]query.PublicKey{keyID: key}
|
||||
}
|
||||
|
||||
func (v *keySet) getKey(ctx context.Context, keyID string, current time.Time) (*jose.JSONWebKey, error) {
|
||||
func (k *keySetCache) getKey(ctx context.Context, keyID string, current time.Time) (*jose.JSONWebKey, error) {
|
||||
instanceID := authz.GetInstance(ctx).InstanceID()
|
||||
|
||||
v.mtx.RLock()
|
||||
key, ok := v.instanceKeys[instanceID][keyID]
|
||||
v.mtx.RUnlock()
|
||||
k.mtx.RLock()
|
||||
key, ok := k.instanceKeys[instanceID][keyID]
|
||||
k.mtx.RUnlock()
|
||||
|
||||
if ok {
|
||||
if key.Expiry().After(current) {
|
||||
@@ -95,24 +101,24 @@ func (v *keySet) getKey(ctx context.Context, keyID string, current time.Time) (*
|
||||
return nil, errors.ThrowInvalidArgument(nil, "OIDC-Zoh9E", "Errors.Key.ExpireBeforeNow")
|
||||
}
|
||||
|
||||
key, err := v.queryKey(ctx, keyID, current)
|
||||
key, err := k.queryKey(ctx, keyID, current)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
v.setKey(instanceID, keyID, key)
|
||||
k.setKey(instanceID, keyID, key)
|
||||
return jsonWebkey(key), nil
|
||||
}
|
||||
|
||||
// VerifySignature implements the oidc.KeySet interface.
|
||||
func (v *keySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
|
||||
func (k *keySetCache) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
|
||||
if len(jws.Signatures) != 1 {
|
||||
return nil, errors.ThrowInvalidArgument(nil, "OIDC-Gid9s", "Errors.Token.Invalid")
|
||||
}
|
||||
key, err := v.getKey(ctx, jws.Signatures[0].Header.KeyID, time.Now())
|
||||
key, err := k.getKey(ctx, jws.Signatures[0].Header.KeyID, time.Now())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return jws.Verify(&key)
|
||||
return jws.Verify(key)
|
||||
}
|
||||
|
||||
func jsonWebkey(key query.PublicKey) *jose.JSONWebKey {
|
||||
@@ -124,6 +130,35 @@ func jsonWebkey(key query.PublicKey) *jose.JSONWebKey {
|
||||
}
|
||||
}
|
||||
|
||||
// keySetMap is a mapping of key IDs to public key data.
|
||||
type keySetMap map[string][]byte
|
||||
|
||||
// getKey finds the keyID and parses the public key data
|
||||
// into a JSONWebKey.
|
||||
func (k keySetMap) getKey(keyID string) (*jose.JSONWebKey, error) {
|
||||
pubKey, err := crypto.BytesToPublicKey(k[keyID])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &jose.JSONWebKey{
|
||||
Key: pubKey,
|
||||
KeyID: keyID,
|
||||
Use: "sig",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// VerifySignature implements the oidc.KeySet interface.
|
||||
func (k keySetMap) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
|
||||
if len(jws.Signatures) != 1 {
|
||||
return nil, errors.ThrowInvalidArgument(nil, "OIDC-Eeth6", "Errors.Token.Invalid")
|
||||
}
|
||||
key, err := k.getKey(jws.Signatures[0].Header.KeyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return jws.Verify(key)
|
||||
}
|
||||
|
||||
const (
|
||||
locksTable = "projections.locks"
|
||||
signingKey = "signing_key"
|
||||
|
@@ -68,7 +68,7 @@ type OPStorage struct {
|
||||
command *command.Commands
|
||||
query *query.Queries
|
||||
eventstore *eventstore.Eventstore
|
||||
keySet *keySet
|
||||
keySet *keySetCache
|
||||
defaultLoginURL string
|
||||
defaultLoginURLV2 string
|
||||
defaultLogoutURLV2 string
|
||||
@@ -122,6 +122,7 @@ func NewServer(
|
||||
server := &Server{
|
||||
storage: storage,
|
||||
LegacyServer: op.NewLegacyServer(provider, endpoints(config.CustomEndpoints)),
|
||||
hashAlg: crypto.NewBCrypt(10), // as we are only verifying in oidc, the cost is already part of the hash string and the config here is irrelevant.
|
||||
}
|
||||
metricTypes := []metrics.MetricType{metrics.MetricTypeRequestCount, metrics.MetricTypeStatusCode, metrics.MetricTypeTotalCount}
|
||||
server.Handler = op.RegisterLegacyServer(server, op.WithHTTPMiddleware(
|
||||
|
@@ -6,6 +6,7 @@ import (
|
||||
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
)
|
||||
|
||||
@@ -13,6 +14,8 @@ type Server struct {
|
||||
http.Handler
|
||||
storage *OPStorage
|
||||
*op.LegacyServer
|
||||
|
||||
hashAlg crypto.HashAlgorithm
|
||||
}
|
||||
|
||||
func endpoints(endpointConfig *EndpointConfig) op.Endpoints {
|
||||
|
24
internal/query/embed/introspection_client_by_id.sql
Normal file
24
internal/query/embed/introspection_client_by_id.sql
Normal file
@@ -0,0 +1,24 @@
|
||||
with config as (
|
||||
select app_id, client_id, client_secret
|
||||
from projections.apps5_api_configs
|
||||
where instance_id = $1
|
||||
and client_id = $2
|
||||
union
|
||||
select app_id, client_id, client_secret
|
||||
from projections.apps5_oidc_configs
|
||||
where instance_id = $1
|
||||
and client_id = $2
|
||||
),
|
||||
keys as (
|
||||
select identifier as client_id, json_object_agg(id, public_key) as public_keys
|
||||
from projections.authn_keys2
|
||||
where $3 = true -- when argument is false, don't waste time on trying to query for keys.
|
||||
and instance_id = $1
|
||||
and identifier = $2
|
||||
and expiration > current_timestamp
|
||||
group by identifier
|
||||
)
|
||||
select apps.project_id, config.client_secret, keys.public_keys from config
|
||||
join projections.apps5 apps on apps.id = config.app_id
|
||||
left join keys on keys.client_id = config.client_id
|
||||
where apps.owner_removed = false;
|
45
internal/query/introspection_client.go
Normal file
45
internal/query/introspection_client.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
_ "embed"
|
||||
|
||||
"github.com/jackc/pgtype"
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
)
|
||||
|
||||
type IntrospectionClient struct {
|
||||
ClientID string
|
||||
ClientSecret *crypto.CryptoValue
|
||||
ProjectID string
|
||||
PublicKeys database.Map[[]byte]
|
||||
}
|
||||
|
||||
//go:embed embed/introspection_client_by_id.sql
|
||||
var introspectionClientByIDQuery string
|
||||
|
||||
func (q *Queries) GetIntrospectionClientByID(ctx context.Context, clientID string, getKeys bool) (_ *IntrospectionClient, err error) {
|
||||
var (
|
||||
instanceID = authz.GetInstance(ctx).InstanceID()
|
||||
client = new(IntrospectionClient)
|
||||
)
|
||||
|
||||
err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
|
||||
var publicKeys pgtype.ByteaArray
|
||||
if err := row.Scan(&client.ClientID, &client.ClientSecret, &client.ProjectID, &publicKeys); err != nil {
|
||||
return err
|
||||
}
|
||||
return publicKeys.AssignTo(&client.PublicKeys)
|
||||
},
|
||||
introspectionClientByIDQuery,
|
||||
instanceID, clientID, getKeys,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
Reference in New Issue
Block a user