mirror of
https://github.com/zitadel/zitadel.git
synced 2025-12-08 04:12:59 +00:00
client and project in single query
This commit is contained in:
@@ -2,6 +2,7 @@ package oidc
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -10,7 +11,10 @@ import (
|
|||||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||||
"github.com/zitadel/oidc/v3/pkg/op"
|
"github.com/zitadel/oidc/v3/pkg/op"
|
||||||
"github.com/zitadel/zitadel/internal/command"
|
"github.com/zitadel/zitadel/internal/command"
|
||||||
|
"github.com/zitadel/zitadel/internal/crypto"
|
||||||
errz "github.com/zitadel/zitadel/internal/errors"
|
errz "github.com/zitadel/zitadel/internal/errors"
|
||||||
|
"github.com/zitadel/zitadel/internal/query"
|
||||||
|
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Server) Introspect(ctx context.Context, r *op.Request[op.IntrospectionRequest]) (_ *op.Response, err error) {
|
func (s *Server) Introspect(ctx context.Context, r *op.Request[op.IntrospectionRequest]) (_ *op.Response, err error) {
|
||||||
@@ -94,43 +98,56 @@ type instrospectionClientResult struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) instrospectionClientAuth(ctx context.Context, cc *op.ClientCredentials, rc chan<- *instrospectionClientResult) {
|
func (s *Server) instrospectionClientAuth(ctx context.Context, cc *op.ClientCredentials, rc chan<- *instrospectionClientResult) {
|
||||||
clientID := cc.ClientID
|
ctx, span := tracing.NewSpan(ctx)
|
||||||
|
|
||||||
if cc.ClientAssertion != "" {
|
clientID, projectID, err := func() (string, string, error) {
|
||||||
verifier := op.NewJWTProfileVerifier(s.storage, op.IssuerFromContext(ctx), 1*time.Hour, time.Second)
|
client, err := s.clientFromCredentials(ctx, cc)
|
||||||
profile, err := op.VerifyJWTAssertion(ctx, cc.ClientAssertion, verifier)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
rc <- &instrospectionClientResult{
|
return "", "", err
|
||||||
err: oidc.ErrUnauthorizedClient().WithParent(err),
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
clientID = profile.Issuer
|
|
||||||
} else {
|
if cc.ClientAssertion != "" {
|
||||||
if err := s.storage.AuthorizeClientIDSecret(ctx, cc.ClientID, cc.ClientSecret); err != nil {
|
verifier := op.NewJWTProfileVerifierKeySet(keySetMap(client.PublicKeys), op.IssuerFromContext(ctx), time.Hour, time.Second)
|
||||||
if err != nil {
|
if _, err := op.VerifyJWTAssertion(ctx, cc.ClientAssertion, verifier); err != nil {
|
||||||
rc <- &instrospectionClientResult{
|
return "", "", oidc.ErrUnauthorizedClient().WithParent(err)
|
||||||
err: oidc.ErrUnauthorizedClient().WithParent(err),
|
}
|
||||||
}
|
} else {
|
||||||
return
|
if err := crypto.CompareHash(client.ClientSecret, []byte(cc.ClientSecret), s.hashAlg); err != nil {
|
||||||
|
return "", "", oidc.ErrUnauthorizedClient().WithParent(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
return client.ClientID, client.ProjectID, nil
|
||||||
|
}()
|
||||||
|
|
||||||
// TODO: give clients their own aggregate, so we can skip this query
|
span.EndWithError(err)
|
||||||
projectID, err := s.storage.query.ProjectIDFromClientID(ctx, clientID, false)
|
|
||||||
if err != nil {
|
|
||||||
rc <- &instrospectionClientResult{err: err}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
rc <- &instrospectionClientResult{
|
rc <- &instrospectionClientResult{
|
||||||
clientID: clientID,
|
clientID: clientID,
|
||||||
projectID: projectID,
|
projectID: projectID,
|
||||||
|
err: err,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// clientFromCredentials parses the client ID early,
|
||||||
|
// and makes a single query for the client for either auth methods.
|
||||||
|
func (s *Server) clientFromCredentials(ctx context.Context, cc *op.ClientCredentials) (client *query.IntrospectionClient, err error) {
|
||||||
|
if cc.ClientAssertion != "" {
|
||||||
|
claims := new(oidc.JWTTokenRequest)
|
||||||
|
if _, err := oidc.ParseToken(cc.ClientAssertion, claims); err != nil {
|
||||||
|
return nil, oidc.ErrUnauthorizedClient().WithParent(err)
|
||||||
|
}
|
||||||
|
client, err = s.storage.query.GetIntrospectionClientByID(ctx, claims.Issuer, true)
|
||||||
|
} else {
|
||||||
|
client, err = s.storage.query.GetIntrospectionClientByID(ctx, cc.ClientID, false)
|
||||||
|
}
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return nil, oidc.ErrUnauthorizedClient().WithParent(err)
|
||||||
|
}
|
||||||
|
// any other error is regarded internal and should not be reported back to the client.
|
||||||
|
return client, err
|
||||||
|
}
|
||||||
|
|
||||||
type introspectionTokenResult struct {
|
type introspectionTokenResult struct {
|
||||||
tokenID string
|
tokenID string
|
||||||
userID string
|
userID string
|
||||||
|
|||||||
@@ -20,14 +20,21 @@ import (
|
|||||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||||
)
|
)
|
||||||
|
|
||||||
type keySet struct {
|
// keySetCache implements oidc.KeySet for Access Token verification.
|
||||||
|
// Public Keys are cached in a 2-dimentional map of Instance ID and Key ID.
|
||||||
|
// When a key is not present the queryKey function is called to obtain the key
|
||||||
|
// from the database.
|
||||||
|
type keySetCache struct {
|
||||||
mtx sync.RWMutex
|
mtx sync.RWMutex
|
||||||
instanceKeys map[string]map[string]query.PublicKey
|
instanceKeys map[string]map[string]query.PublicKey
|
||||||
queryKey func(ctx context.Context, keyID string, current time.Time) (query.PublicKey, error)
|
queryKey func(ctx context.Context, keyID string, current time.Time) (query.PublicKey, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newKeySet(background context.Context, purgeInterval time.Duration, queryKey func(ctx context.Context, keyID string, current time.Time) (query.PublicKey, error)) *keySet {
|
// newKeySet initializes a keySetCache and starts a purging Go routine,
|
||||||
k := &keySet{
|
// which runs once every purgeInterval.
|
||||||
|
// When the passed context is done, the purge routine will terminate.
|
||||||
|
func newKeySet(background context.Context, purgeInterval time.Duration, queryKey func(ctx context.Context, keyID string, current time.Time) (query.PublicKey, error)) *keySetCache {
|
||||||
|
k := &keySetCache{
|
||||||
instanceKeys: make(map[string]map[string]query.PublicKey),
|
instanceKeys: make(map[string]map[string]query.PublicKey),
|
||||||
queryKey: queryKey,
|
queryKey: queryKey,
|
||||||
}
|
}
|
||||||
@@ -35,7 +42,7 @@ func newKeySet(background context.Context, purgeInterval time.Duration, queryKey
|
|||||||
return k
|
return k
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *keySet) purgeOnInterval(background context.Context, purgeInterval time.Duration) {
|
func (k *keySetCache) purgeOnInterval(background context.Context, purgeInterval time.Duration) {
|
||||||
timer := time.NewTimer(purgeInterval)
|
timer := time.NewTimer(purgeInterval)
|
||||||
defer func() {
|
defer func() {
|
||||||
if !timer.Stop() {
|
if !timer.Stop() {
|
||||||
@@ -43,50 +50,49 @@ func (v *keySet) purgeOnInterval(background context.Context, purgeInterval time.
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
loop:
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-background.Done():
|
case <-background.Done():
|
||||||
break loop
|
return
|
||||||
case <-timer.C:
|
case <-timer.C:
|
||||||
timer.Reset(purgeInterval)
|
timer.Reset(purgeInterval)
|
||||||
}
|
}
|
||||||
|
|
||||||
// do the actual purging
|
// do the actual purging
|
||||||
v.mtx.Lock()
|
k.mtx.Lock()
|
||||||
for instanceID, keys := range v.instanceKeys {
|
for instanceID, keys := range k.instanceKeys {
|
||||||
for keyID, key := range keys {
|
for keyID, key := range keys {
|
||||||
if key.Expiry().Before(time.Now()) {
|
if key.Expiry().Before(time.Now()) {
|
||||||
delete(keys, keyID)
|
delete(keys, keyID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(keys) == 0 {
|
if len(keys) == 0 {
|
||||||
delete(v.instanceKeys, instanceID)
|
delete(k.instanceKeys, instanceID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
v.mtx.Unlock()
|
k.mtx.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *keySet) setKey(instanceID, keyID string, key query.PublicKey) {
|
func (k *keySetCache) setKey(instanceID, keyID string, key query.PublicKey) {
|
||||||
v.mtx.Lock()
|
k.mtx.Lock()
|
||||||
defer v.mtx.Unlock()
|
defer k.mtx.Unlock()
|
||||||
|
|
||||||
if keys, ok := v.instanceKeys[instanceID]; ok {
|
if keys, ok := k.instanceKeys[instanceID]; ok {
|
||||||
keys[keyID] = key
|
keys[keyID] = key
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
v.instanceKeys[instanceID] = map[string]query.PublicKey{keyID: key}
|
k.instanceKeys[instanceID] = map[string]query.PublicKey{keyID: key}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *keySet) getKey(ctx context.Context, keyID string, current time.Time) (*jose.JSONWebKey, error) {
|
func (k *keySetCache) getKey(ctx context.Context, keyID string, current time.Time) (*jose.JSONWebKey, error) {
|
||||||
instanceID := authz.GetInstance(ctx).InstanceID()
|
instanceID := authz.GetInstance(ctx).InstanceID()
|
||||||
|
|
||||||
v.mtx.RLock()
|
k.mtx.RLock()
|
||||||
key, ok := v.instanceKeys[instanceID][keyID]
|
key, ok := k.instanceKeys[instanceID][keyID]
|
||||||
v.mtx.RUnlock()
|
k.mtx.RUnlock()
|
||||||
|
|
||||||
if ok {
|
if ok {
|
||||||
if key.Expiry().After(current) {
|
if key.Expiry().After(current) {
|
||||||
@@ -95,24 +101,24 @@ func (v *keySet) getKey(ctx context.Context, keyID string, current time.Time) (*
|
|||||||
return nil, errors.ThrowInvalidArgument(nil, "OIDC-Zoh9E", "Errors.Key.ExpireBeforeNow")
|
return nil, errors.ThrowInvalidArgument(nil, "OIDC-Zoh9E", "Errors.Key.ExpireBeforeNow")
|
||||||
}
|
}
|
||||||
|
|
||||||
key, err := v.queryKey(ctx, keyID, current)
|
key, err := k.queryKey(ctx, keyID, current)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
v.setKey(instanceID, keyID, key)
|
k.setKey(instanceID, keyID, key)
|
||||||
return jsonWebkey(key), nil
|
return jsonWebkey(key), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// VerifySignature implements the oidc.KeySet interface.
|
// VerifySignature implements the oidc.KeySet interface.
|
||||||
func (v *keySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
|
func (k *keySetCache) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
|
||||||
if len(jws.Signatures) != 1 {
|
if len(jws.Signatures) != 1 {
|
||||||
return nil, errors.ThrowInvalidArgument(nil, "OIDC-Gid9s", "Errors.Token.Invalid")
|
return nil, errors.ThrowInvalidArgument(nil, "OIDC-Gid9s", "Errors.Token.Invalid")
|
||||||
}
|
}
|
||||||
key, err := v.getKey(ctx, jws.Signatures[0].Header.KeyID, time.Now())
|
key, err := k.getKey(ctx, jws.Signatures[0].Header.KeyID, time.Now())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return jws.Verify(&key)
|
return jws.Verify(key)
|
||||||
}
|
}
|
||||||
|
|
||||||
func jsonWebkey(key query.PublicKey) *jose.JSONWebKey {
|
func jsonWebkey(key query.PublicKey) *jose.JSONWebKey {
|
||||||
@@ -124,6 +130,35 @@ func jsonWebkey(key query.PublicKey) *jose.JSONWebKey {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// keySetMap is a mapping of key IDs to public key data.
|
||||||
|
type keySetMap map[string][]byte
|
||||||
|
|
||||||
|
// getKey finds the keyID and parses the public key data
|
||||||
|
// into a JSONWebKey.
|
||||||
|
func (k keySetMap) getKey(keyID string) (*jose.JSONWebKey, error) {
|
||||||
|
pubKey, err := crypto.BytesToPublicKey(k[keyID])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &jose.JSONWebKey{
|
||||||
|
Key: pubKey,
|
||||||
|
KeyID: keyID,
|
||||||
|
Use: "sig",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// VerifySignature implements the oidc.KeySet interface.
|
||||||
|
func (k keySetMap) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
|
||||||
|
if len(jws.Signatures) != 1 {
|
||||||
|
return nil, errors.ThrowInvalidArgument(nil, "OIDC-Eeth6", "Errors.Token.Invalid")
|
||||||
|
}
|
||||||
|
key, err := k.getKey(jws.Signatures[0].Header.KeyID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return jws.Verify(key)
|
||||||
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
locksTable = "projections.locks"
|
locksTable = "projections.locks"
|
||||||
signingKey = "signing_key"
|
signingKey = "signing_key"
|
||||||
|
|||||||
@@ -68,7 +68,7 @@ type OPStorage struct {
|
|||||||
command *command.Commands
|
command *command.Commands
|
||||||
query *query.Queries
|
query *query.Queries
|
||||||
eventstore *eventstore.Eventstore
|
eventstore *eventstore.Eventstore
|
||||||
keySet *keySet
|
keySet *keySetCache
|
||||||
defaultLoginURL string
|
defaultLoginURL string
|
||||||
defaultLoginURLV2 string
|
defaultLoginURLV2 string
|
||||||
defaultLogoutURLV2 string
|
defaultLogoutURLV2 string
|
||||||
@@ -122,6 +122,7 @@ func NewServer(
|
|||||||
server := &Server{
|
server := &Server{
|
||||||
storage: storage,
|
storage: storage,
|
||||||
LegacyServer: op.NewLegacyServer(provider, endpoints(config.CustomEndpoints)),
|
LegacyServer: op.NewLegacyServer(provider, endpoints(config.CustomEndpoints)),
|
||||||
|
hashAlg: crypto.NewBCrypt(10), // as we are only verifying in oidc, the cost is already part of the hash string and the config here is irrelevant.
|
||||||
}
|
}
|
||||||
metricTypes := []metrics.MetricType{metrics.MetricTypeRequestCount, metrics.MetricTypeStatusCode, metrics.MetricTypeTotalCount}
|
metricTypes := []metrics.MetricType{metrics.MetricTypeRequestCount, metrics.MetricTypeStatusCode, metrics.MetricTypeTotalCount}
|
||||||
server.Handler = op.RegisterLegacyServer(server, op.WithHTTPMiddleware(
|
server.Handler = op.RegisterLegacyServer(server, op.WithHTTPMiddleware(
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
|
|
||||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||||
"github.com/zitadel/oidc/v3/pkg/op"
|
"github.com/zitadel/oidc/v3/pkg/op"
|
||||||
|
"github.com/zitadel/zitadel/internal/crypto"
|
||||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -13,6 +14,8 @@ type Server struct {
|
|||||||
http.Handler
|
http.Handler
|
||||||
storage *OPStorage
|
storage *OPStorage
|
||||||
*op.LegacyServer
|
*op.LegacyServer
|
||||||
|
|
||||||
|
hashAlg crypto.HashAlgorithm
|
||||||
}
|
}
|
||||||
|
|
||||||
func endpoints(endpointConfig *EndpointConfig) op.Endpoints {
|
func endpoints(endpointConfig *EndpointConfig) op.Endpoints {
|
||||||
|
|||||||
24
internal/query/embed/introspection_client_by_id.sql
Normal file
24
internal/query/embed/introspection_client_by_id.sql
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
with config as (
|
||||||
|
select app_id, client_id, client_secret
|
||||||
|
from projections.apps5_api_configs
|
||||||
|
where instance_id = $1
|
||||||
|
and client_id = $2
|
||||||
|
union
|
||||||
|
select app_id, client_id, client_secret
|
||||||
|
from projections.apps5_oidc_configs
|
||||||
|
where instance_id = $1
|
||||||
|
and client_id = $2
|
||||||
|
),
|
||||||
|
keys as (
|
||||||
|
select identifier as client_id, json_object_agg(id, public_key) as public_keys
|
||||||
|
from projections.authn_keys2
|
||||||
|
where $3 = true -- when argument is false, don't waste time on trying to query for keys.
|
||||||
|
and instance_id = $1
|
||||||
|
and identifier = $2
|
||||||
|
and expiration > current_timestamp
|
||||||
|
group by identifier
|
||||||
|
)
|
||||||
|
select apps.project_id, config.client_secret, keys.public_keys from config
|
||||||
|
join projections.apps5 apps on apps.id = config.app_id
|
||||||
|
left join keys on keys.client_id = config.client_id
|
||||||
|
where apps.owner_removed = false;
|
||||||
45
internal/query/introspection_client.go
Normal file
45
internal/query/introspection_client.go
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
package query
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
_ "embed"
|
||||||
|
|
||||||
|
"github.com/jackc/pgtype"
|
||||||
|
"github.com/zitadel/zitadel/internal/api/authz"
|
||||||
|
"github.com/zitadel/zitadel/internal/crypto"
|
||||||
|
"github.com/zitadel/zitadel/internal/database"
|
||||||
|
)
|
||||||
|
|
||||||
|
type IntrospectionClient struct {
|
||||||
|
ClientID string
|
||||||
|
ClientSecret *crypto.CryptoValue
|
||||||
|
ProjectID string
|
||||||
|
PublicKeys database.Map[[]byte]
|
||||||
|
}
|
||||||
|
|
||||||
|
//go:embed embed/introspection_client_by_id.sql
|
||||||
|
var introspectionClientByIDQuery string
|
||||||
|
|
||||||
|
func (q *Queries) GetIntrospectionClientByID(ctx context.Context, clientID string, getKeys bool) (_ *IntrospectionClient, err error) {
|
||||||
|
var (
|
||||||
|
instanceID = authz.GetInstance(ctx).InstanceID()
|
||||||
|
client = new(IntrospectionClient)
|
||||||
|
)
|
||||||
|
|
||||||
|
err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
|
||||||
|
var publicKeys pgtype.ByteaArray
|
||||||
|
if err := row.Scan(&client.ClientID, &client.ClientSecret, &client.ProjectID, &publicKeys); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return publicKeys.AssignTo(&client.PublicKeys)
|
||||||
|
},
|
||||||
|
introspectionClientByIDQuery,
|
||||||
|
instanceID, clientID, getKeys,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return client, nil
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user