client and project in single query

This commit is contained in:
Tim Möhlmann
2023-11-05 13:18:17 +02:00
parent 36baf36877
commit 66f91cdc4e
6 changed files with 173 additions and 48 deletions

View File

@@ -2,6 +2,7 @@ package oidc
import ( import (
"context" "context"
"database/sql"
"errors" "errors"
"slices" "slices"
"strings" "strings"
@@ -10,7 +11,10 @@ import (
"github.com/zitadel/oidc/v3/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op" "github.com/zitadel/oidc/v3/pkg/op"
"github.com/zitadel/zitadel/internal/command" "github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/crypto"
errz "github.com/zitadel/zitadel/internal/errors" errz "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
) )
func (s *Server) Introspect(ctx context.Context, r *op.Request[op.IntrospectionRequest]) (_ *op.Response, err error) { func (s *Server) Introspect(ctx context.Context, r *op.Request[op.IntrospectionRequest]) (_ *op.Response, err error) {
@@ -94,43 +98,56 @@ type instrospectionClientResult struct {
} }
func (s *Server) instrospectionClientAuth(ctx context.Context, cc *op.ClientCredentials, rc chan<- *instrospectionClientResult) { func (s *Server) instrospectionClientAuth(ctx context.Context, cc *op.ClientCredentials, rc chan<- *instrospectionClientResult) {
clientID := cc.ClientID ctx, span := tracing.NewSpan(ctx)
if cc.ClientAssertion != "" { clientID, projectID, err := func() (string, string, error) {
verifier := op.NewJWTProfileVerifier(s.storage, op.IssuerFromContext(ctx), 1*time.Hour, time.Second) client, err := s.clientFromCredentials(ctx, cc)
profile, err := op.VerifyJWTAssertion(ctx, cc.ClientAssertion, verifier)
if err != nil { if err != nil {
rc <- &instrospectionClientResult{ return "", "", err
err: oidc.ErrUnauthorizedClient().WithParent(err),
}
return
} }
clientID = profile.Issuer
} else { if cc.ClientAssertion != "" {
if err := s.storage.AuthorizeClientIDSecret(ctx, cc.ClientID, cc.ClientSecret); err != nil { verifier := op.NewJWTProfileVerifierKeySet(keySetMap(client.PublicKeys), op.IssuerFromContext(ctx), time.Hour, time.Second)
if err != nil { if _, err := op.VerifyJWTAssertion(ctx, cc.ClientAssertion, verifier); err != nil {
rc <- &instrospectionClientResult{ return "", "", oidc.ErrUnauthorizedClient().WithParent(err)
err: oidc.ErrUnauthorizedClient().WithParent(err), }
} } else {
return if err := crypto.CompareHash(client.ClientSecret, []byte(cc.ClientSecret), s.hashAlg); err != nil {
return "", "", oidc.ErrUnauthorizedClient().WithParent(err)
} }
} }
} return client.ClientID, client.ProjectID, nil
}()
// TODO: give clients their own aggregate, so we can skip this query span.EndWithError(err)
projectID, err := s.storage.query.ProjectIDFromClientID(ctx, clientID, false)
if err != nil {
rc <- &instrospectionClientResult{err: err}
return
}
rc <- &instrospectionClientResult{ rc <- &instrospectionClientResult{
clientID: clientID, clientID: clientID,
projectID: projectID, projectID: projectID,
err: err,
} }
} }
// clientFromCredentials parses the client ID early,
// and makes a single query for the client for either auth methods.
func (s *Server) clientFromCredentials(ctx context.Context, cc *op.ClientCredentials) (client *query.IntrospectionClient, err error) {
if cc.ClientAssertion != "" {
claims := new(oidc.JWTTokenRequest)
if _, err := oidc.ParseToken(cc.ClientAssertion, claims); err != nil {
return nil, oidc.ErrUnauthorizedClient().WithParent(err)
}
client, err = s.storage.query.GetIntrospectionClientByID(ctx, claims.Issuer, true)
} else {
client, err = s.storage.query.GetIntrospectionClientByID(ctx, cc.ClientID, false)
}
if errors.Is(err, sql.ErrNoRows) {
return nil, oidc.ErrUnauthorizedClient().WithParent(err)
}
// any other error is regarded internal and should not be reported back to the client.
return client, err
}
type introspectionTokenResult struct { type introspectionTokenResult struct {
tokenID string tokenID string
userID string userID string

View File

@@ -20,14 +20,21 @@ import (
"github.com/zitadel/zitadel/internal/telemetry/tracing" "github.com/zitadel/zitadel/internal/telemetry/tracing"
) )
type keySet struct { // keySetCache implements oidc.KeySet for Access Token verification.
// Public Keys are cached in a 2-dimentional map of Instance ID and Key ID.
// When a key is not present the queryKey function is called to obtain the key
// from the database.
type keySetCache struct {
mtx sync.RWMutex mtx sync.RWMutex
instanceKeys map[string]map[string]query.PublicKey instanceKeys map[string]map[string]query.PublicKey
queryKey func(ctx context.Context, keyID string, current time.Time) (query.PublicKey, error) queryKey func(ctx context.Context, keyID string, current time.Time) (query.PublicKey, error)
} }
func newKeySet(background context.Context, purgeInterval time.Duration, queryKey func(ctx context.Context, keyID string, current time.Time) (query.PublicKey, error)) *keySet { // newKeySet initializes a keySetCache and starts a purging Go routine,
k := &keySet{ // which runs once every purgeInterval.
// When the passed context is done, the purge routine will terminate.
func newKeySet(background context.Context, purgeInterval time.Duration, queryKey func(ctx context.Context, keyID string, current time.Time) (query.PublicKey, error)) *keySetCache {
k := &keySetCache{
instanceKeys: make(map[string]map[string]query.PublicKey), instanceKeys: make(map[string]map[string]query.PublicKey),
queryKey: queryKey, queryKey: queryKey,
} }
@@ -35,7 +42,7 @@ func newKeySet(background context.Context, purgeInterval time.Duration, queryKey
return k return k
} }
func (v *keySet) purgeOnInterval(background context.Context, purgeInterval time.Duration) { func (k *keySetCache) purgeOnInterval(background context.Context, purgeInterval time.Duration) {
timer := time.NewTimer(purgeInterval) timer := time.NewTimer(purgeInterval)
defer func() { defer func() {
if !timer.Stop() { if !timer.Stop() {
@@ -43,50 +50,49 @@ func (v *keySet) purgeOnInterval(background context.Context, purgeInterval time.
} }
}() }()
loop:
for { for {
select { select {
case <-background.Done(): case <-background.Done():
break loop return
case <-timer.C: case <-timer.C:
timer.Reset(purgeInterval) timer.Reset(purgeInterval)
} }
// do the actual purging // do the actual purging
v.mtx.Lock() k.mtx.Lock()
for instanceID, keys := range v.instanceKeys { for instanceID, keys := range k.instanceKeys {
for keyID, key := range keys { for keyID, key := range keys {
if key.Expiry().Before(time.Now()) { if key.Expiry().Before(time.Now()) {
delete(keys, keyID) delete(keys, keyID)
} }
} }
if len(keys) == 0 { if len(keys) == 0 {
delete(v.instanceKeys, instanceID) delete(k.instanceKeys, instanceID)
} }
} }
v.mtx.Unlock() k.mtx.Unlock()
} }
} }
func (v *keySet) setKey(instanceID, keyID string, key query.PublicKey) { func (k *keySetCache) setKey(instanceID, keyID string, key query.PublicKey) {
v.mtx.Lock() k.mtx.Lock()
defer v.mtx.Unlock() defer k.mtx.Unlock()
if keys, ok := v.instanceKeys[instanceID]; ok { if keys, ok := k.instanceKeys[instanceID]; ok {
keys[keyID] = key keys[keyID] = key
return return
} }
v.instanceKeys[instanceID] = map[string]query.PublicKey{keyID: key} k.instanceKeys[instanceID] = map[string]query.PublicKey{keyID: key}
} }
func (v *keySet) getKey(ctx context.Context, keyID string, current time.Time) (*jose.JSONWebKey, error) { func (k *keySetCache) getKey(ctx context.Context, keyID string, current time.Time) (*jose.JSONWebKey, error) {
instanceID := authz.GetInstance(ctx).InstanceID() instanceID := authz.GetInstance(ctx).InstanceID()
v.mtx.RLock() k.mtx.RLock()
key, ok := v.instanceKeys[instanceID][keyID] key, ok := k.instanceKeys[instanceID][keyID]
v.mtx.RUnlock() k.mtx.RUnlock()
if ok { if ok {
if key.Expiry().After(current) { if key.Expiry().After(current) {
@@ -95,24 +101,24 @@ func (v *keySet) getKey(ctx context.Context, keyID string, current time.Time) (*
return nil, errors.ThrowInvalidArgument(nil, "OIDC-Zoh9E", "Errors.Key.ExpireBeforeNow") return nil, errors.ThrowInvalidArgument(nil, "OIDC-Zoh9E", "Errors.Key.ExpireBeforeNow")
} }
key, err := v.queryKey(ctx, keyID, current) key, err := k.queryKey(ctx, keyID, current)
if err != nil { if err != nil {
return nil, err return nil, err
} }
v.setKey(instanceID, keyID, key) k.setKey(instanceID, keyID, key)
return jsonWebkey(key), nil return jsonWebkey(key), nil
} }
// VerifySignature implements the oidc.KeySet interface. // VerifySignature implements the oidc.KeySet interface.
func (v *keySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) { func (k *keySetCache) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
if len(jws.Signatures) != 1 { if len(jws.Signatures) != 1 {
return nil, errors.ThrowInvalidArgument(nil, "OIDC-Gid9s", "Errors.Token.Invalid") return nil, errors.ThrowInvalidArgument(nil, "OIDC-Gid9s", "Errors.Token.Invalid")
} }
key, err := v.getKey(ctx, jws.Signatures[0].Header.KeyID, time.Now()) key, err := k.getKey(ctx, jws.Signatures[0].Header.KeyID, time.Now())
if err != nil { if err != nil {
return nil, err return nil, err
} }
return jws.Verify(&key) return jws.Verify(key)
} }
func jsonWebkey(key query.PublicKey) *jose.JSONWebKey { func jsonWebkey(key query.PublicKey) *jose.JSONWebKey {
@@ -124,6 +130,35 @@ func jsonWebkey(key query.PublicKey) *jose.JSONWebKey {
} }
} }
// keySetMap is a mapping of key IDs to public key data.
type keySetMap map[string][]byte
// getKey finds the keyID and parses the public key data
// into a JSONWebKey.
func (k keySetMap) getKey(keyID string) (*jose.JSONWebKey, error) {
pubKey, err := crypto.BytesToPublicKey(k[keyID])
if err != nil {
return nil, err
}
return &jose.JSONWebKey{
Key: pubKey,
KeyID: keyID,
Use: "sig",
}, nil
}
// VerifySignature implements the oidc.KeySet interface.
func (k keySetMap) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
if len(jws.Signatures) != 1 {
return nil, errors.ThrowInvalidArgument(nil, "OIDC-Eeth6", "Errors.Token.Invalid")
}
key, err := k.getKey(jws.Signatures[0].Header.KeyID)
if err != nil {
return nil, err
}
return jws.Verify(key)
}
const ( const (
locksTable = "projections.locks" locksTable = "projections.locks"
signingKey = "signing_key" signingKey = "signing_key"

View File

@@ -68,7 +68,7 @@ type OPStorage struct {
command *command.Commands command *command.Commands
query *query.Queries query *query.Queries
eventstore *eventstore.Eventstore eventstore *eventstore.Eventstore
keySet *keySet keySet *keySetCache
defaultLoginURL string defaultLoginURL string
defaultLoginURLV2 string defaultLoginURLV2 string
defaultLogoutURLV2 string defaultLogoutURLV2 string
@@ -122,6 +122,7 @@ func NewServer(
server := &Server{ server := &Server{
storage: storage, storage: storage,
LegacyServer: op.NewLegacyServer(provider, endpoints(config.CustomEndpoints)), LegacyServer: op.NewLegacyServer(provider, endpoints(config.CustomEndpoints)),
hashAlg: crypto.NewBCrypt(10), // as we are only verifying in oidc, the cost is already part of the hash string and the config here is irrelevant.
} }
metricTypes := []metrics.MetricType{metrics.MetricTypeRequestCount, metrics.MetricTypeStatusCode, metrics.MetricTypeTotalCount} metricTypes := []metrics.MetricType{metrics.MetricTypeRequestCount, metrics.MetricTypeStatusCode, metrics.MetricTypeTotalCount}
server.Handler = op.RegisterLegacyServer(server, op.WithHTTPMiddleware( server.Handler = op.RegisterLegacyServer(server, op.WithHTTPMiddleware(

View File

@@ -6,6 +6,7 @@ import (
"github.com/zitadel/oidc/v3/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op" "github.com/zitadel/oidc/v3/pkg/op"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/telemetry/tracing" "github.com/zitadel/zitadel/internal/telemetry/tracing"
) )
@@ -13,6 +14,8 @@ type Server struct {
http.Handler http.Handler
storage *OPStorage storage *OPStorage
*op.LegacyServer *op.LegacyServer
hashAlg crypto.HashAlgorithm
} }
func endpoints(endpointConfig *EndpointConfig) op.Endpoints { func endpoints(endpointConfig *EndpointConfig) op.Endpoints {

View File

@@ -0,0 +1,24 @@
with config as (
select app_id, client_id, client_secret
from projections.apps5_api_configs
where instance_id = $1
and client_id = $2
union
select app_id, client_id, client_secret
from projections.apps5_oidc_configs
where instance_id = $1
and client_id = $2
),
keys as (
select identifier as client_id, json_object_agg(id, public_key) as public_keys
from projections.authn_keys2
where $3 = true -- when argument is false, don't waste time on trying to query for keys.
and instance_id = $1
and identifier = $2
and expiration > current_timestamp
group by identifier
)
select apps.project_id, config.client_secret, keys.public_keys from config
join projections.apps5 apps on apps.id = config.app_id
left join keys on keys.client_id = config.client_id
where apps.owner_removed = false;

View File

@@ -0,0 +1,45 @@
package query
import (
"context"
"database/sql"
_ "embed"
"github.com/jackc/pgtype"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/database"
)
type IntrospectionClient struct {
ClientID string
ClientSecret *crypto.CryptoValue
ProjectID string
PublicKeys database.Map[[]byte]
}
//go:embed embed/introspection_client_by_id.sql
var introspectionClientByIDQuery string
func (q *Queries) GetIntrospectionClientByID(ctx context.Context, clientID string, getKeys bool) (_ *IntrospectionClient, err error) {
var (
instanceID = authz.GetInstance(ctx).InstanceID()
client = new(IntrospectionClient)
)
err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
var publicKeys pgtype.ByteaArray
if err := row.Scan(&client.ClientID, &client.ClientSecret, &client.ProjectID, &publicKeys); err != nil {
return err
}
return publicKeys.AssignTo(&client.PublicKeys)
},
introspectionClientByIDQuery,
instanceID, clientID, getKeys,
)
if err != nil {
return nil, err
}
return client, nil
}