client and project in single query

This commit is contained in:
Tim Möhlmann
2023-11-05 13:18:17 +02:00
parent 36baf36877
commit 66f91cdc4e
6 changed files with 173 additions and 48 deletions

View File

@@ -2,6 +2,7 @@ package oidc
import (
"context"
"database/sql"
"errors"
"slices"
"strings"
@@ -10,7 +11,10 @@ import (
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/crypto"
errz "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
func (s *Server) Introspect(ctx context.Context, r *op.Request[op.IntrospectionRequest]) (_ *op.Response, err error) {
@@ -94,43 +98,56 @@ type instrospectionClientResult struct {
}
func (s *Server) instrospectionClientAuth(ctx context.Context, cc *op.ClientCredentials, rc chan<- *instrospectionClientResult) {
clientID := cc.ClientID
ctx, span := tracing.NewSpan(ctx)
if cc.ClientAssertion != "" {
verifier := op.NewJWTProfileVerifier(s.storage, op.IssuerFromContext(ctx), 1*time.Hour, time.Second)
profile, err := op.VerifyJWTAssertion(ctx, cc.ClientAssertion, verifier)
clientID, projectID, err := func() (string, string, error) {
client, err := s.clientFromCredentials(ctx, cc)
if err != nil {
rc <- &instrospectionClientResult{
err: oidc.ErrUnauthorizedClient().WithParent(err),
}
return
return "", "", err
}
clientID = profile.Issuer
} else {
if err := s.storage.AuthorizeClientIDSecret(ctx, cc.ClientID, cc.ClientSecret); err != nil {
if err != nil {
rc <- &instrospectionClientResult{
err: oidc.ErrUnauthorizedClient().WithParent(err),
}
return
if cc.ClientAssertion != "" {
verifier := op.NewJWTProfileVerifierKeySet(keySetMap(client.PublicKeys), op.IssuerFromContext(ctx), time.Hour, time.Second)
if _, err := op.VerifyJWTAssertion(ctx, cc.ClientAssertion, verifier); err != nil {
return "", "", oidc.ErrUnauthorizedClient().WithParent(err)
}
} else {
if err := crypto.CompareHash(client.ClientSecret, []byte(cc.ClientSecret), s.hashAlg); err != nil {
return "", "", oidc.ErrUnauthorizedClient().WithParent(err)
}
}
}
return client.ClientID, client.ProjectID, nil
}()
// TODO: give clients their own aggregate, so we can skip this query
projectID, err := s.storage.query.ProjectIDFromClientID(ctx, clientID, false)
if err != nil {
rc <- &instrospectionClientResult{err: err}
return
}
span.EndWithError(err)
rc <- &instrospectionClientResult{
clientID: clientID,
projectID: projectID,
err: err,
}
}
// clientFromCredentials parses the client ID early,
// and makes a single query for the client for either auth methods.
func (s *Server) clientFromCredentials(ctx context.Context, cc *op.ClientCredentials) (client *query.IntrospectionClient, err error) {
if cc.ClientAssertion != "" {
claims := new(oidc.JWTTokenRequest)
if _, err := oidc.ParseToken(cc.ClientAssertion, claims); err != nil {
return nil, oidc.ErrUnauthorizedClient().WithParent(err)
}
client, err = s.storage.query.GetIntrospectionClientByID(ctx, claims.Issuer, true)
} else {
client, err = s.storage.query.GetIntrospectionClientByID(ctx, cc.ClientID, false)
}
if errors.Is(err, sql.ErrNoRows) {
return nil, oidc.ErrUnauthorizedClient().WithParent(err)
}
// any other error is regarded internal and should not be reported back to the client.
return client, err
}
type introspectionTokenResult struct {
tokenID string
userID string

View File

@@ -20,14 +20,21 @@ import (
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
type keySet struct {
// keySetCache implements oidc.KeySet for Access Token verification.
// Public Keys are cached in a 2-dimentional map of Instance ID and Key ID.
// When a key is not present the queryKey function is called to obtain the key
// from the database.
type keySetCache struct {
mtx sync.RWMutex
instanceKeys map[string]map[string]query.PublicKey
queryKey func(ctx context.Context, keyID string, current time.Time) (query.PublicKey, error)
}
func newKeySet(background context.Context, purgeInterval time.Duration, queryKey func(ctx context.Context, keyID string, current time.Time) (query.PublicKey, error)) *keySet {
k := &keySet{
// newKeySet initializes a keySetCache and starts a purging Go routine,
// which runs once every purgeInterval.
// When the passed context is done, the purge routine will terminate.
func newKeySet(background context.Context, purgeInterval time.Duration, queryKey func(ctx context.Context, keyID string, current time.Time) (query.PublicKey, error)) *keySetCache {
k := &keySetCache{
instanceKeys: make(map[string]map[string]query.PublicKey),
queryKey: queryKey,
}
@@ -35,7 +42,7 @@ func newKeySet(background context.Context, purgeInterval time.Duration, queryKey
return k
}
func (v *keySet) purgeOnInterval(background context.Context, purgeInterval time.Duration) {
func (k *keySetCache) purgeOnInterval(background context.Context, purgeInterval time.Duration) {
timer := time.NewTimer(purgeInterval)
defer func() {
if !timer.Stop() {
@@ -43,50 +50,49 @@ func (v *keySet) purgeOnInterval(background context.Context, purgeInterval time.
}
}()
loop:
for {
select {
case <-background.Done():
break loop
return
case <-timer.C:
timer.Reset(purgeInterval)
}
// do the actual purging
v.mtx.Lock()
for instanceID, keys := range v.instanceKeys {
k.mtx.Lock()
for instanceID, keys := range k.instanceKeys {
for keyID, key := range keys {
if key.Expiry().Before(time.Now()) {
delete(keys, keyID)
}
}
if len(keys) == 0 {
delete(v.instanceKeys, instanceID)
delete(k.instanceKeys, instanceID)
}
}
v.mtx.Unlock()
k.mtx.Unlock()
}
}
func (v *keySet) setKey(instanceID, keyID string, key query.PublicKey) {
v.mtx.Lock()
defer v.mtx.Unlock()
func (k *keySetCache) setKey(instanceID, keyID string, key query.PublicKey) {
k.mtx.Lock()
defer k.mtx.Unlock()
if keys, ok := v.instanceKeys[instanceID]; ok {
if keys, ok := k.instanceKeys[instanceID]; ok {
keys[keyID] = key
return
}
v.instanceKeys[instanceID] = map[string]query.PublicKey{keyID: key}
k.instanceKeys[instanceID] = map[string]query.PublicKey{keyID: key}
}
func (v *keySet) getKey(ctx context.Context, keyID string, current time.Time) (*jose.JSONWebKey, error) {
func (k *keySetCache) getKey(ctx context.Context, keyID string, current time.Time) (*jose.JSONWebKey, error) {
instanceID := authz.GetInstance(ctx).InstanceID()
v.mtx.RLock()
key, ok := v.instanceKeys[instanceID][keyID]
v.mtx.RUnlock()
k.mtx.RLock()
key, ok := k.instanceKeys[instanceID][keyID]
k.mtx.RUnlock()
if ok {
if key.Expiry().After(current) {
@@ -95,24 +101,24 @@ func (v *keySet) getKey(ctx context.Context, keyID string, current time.Time) (*
return nil, errors.ThrowInvalidArgument(nil, "OIDC-Zoh9E", "Errors.Key.ExpireBeforeNow")
}
key, err := v.queryKey(ctx, keyID, current)
key, err := k.queryKey(ctx, keyID, current)
if err != nil {
return nil, err
}
v.setKey(instanceID, keyID, key)
k.setKey(instanceID, keyID, key)
return jsonWebkey(key), nil
}
// VerifySignature implements the oidc.KeySet interface.
func (v *keySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
func (k *keySetCache) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
if len(jws.Signatures) != 1 {
return nil, errors.ThrowInvalidArgument(nil, "OIDC-Gid9s", "Errors.Token.Invalid")
}
key, err := v.getKey(ctx, jws.Signatures[0].Header.KeyID, time.Now())
key, err := k.getKey(ctx, jws.Signatures[0].Header.KeyID, time.Now())
if err != nil {
return nil, err
}
return jws.Verify(&key)
return jws.Verify(key)
}
func jsonWebkey(key query.PublicKey) *jose.JSONWebKey {
@@ -124,6 +130,35 @@ func jsonWebkey(key query.PublicKey) *jose.JSONWebKey {
}
}
// keySetMap is a mapping of key IDs to public key data.
type keySetMap map[string][]byte
// getKey finds the keyID and parses the public key data
// into a JSONWebKey.
func (k keySetMap) getKey(keyID string) (*jose.JSONWebKey, error) {
pubKey, err := crypto.BytesToPublicKey(k[keyID])
if err != nil {
return nil, err
}
return &jose.JSONWebKey{
Key: pubKey,
KeyID: keyID,
Use: "sig",
}, nil
}
// VerifySignature implements the oidc.KeySet interface.
func (k keySetMap) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
if len(jws.Signatures) != 1 {
return nil, errors.ThrowInvalidArgument(nil, "OIDC-Eeth6", "Errors.Token.Invalid")
}
key, err := k.getKey(jws.Signatures[0].Header.KeyID)
if err != nil {
return nil, err
}
return jws.Verify(key)
}
const (
locksTable = "projections.locks"
signingKey = "signing_key"

View File

@@ -68,7 +68,7 @@ type OPStorage struct {
command *command.Commands
query *query.Queries
eventstore *eventstore.Eventstore
keySet *keySet
keySet *keySetCache
defaultLoginURL string
defaultLoginURLV2 string
defaultLogoutURLV2 string
@@ -122,6 +122,7 @@ func NewServer(
server := &Server{
storage: storage,
LegacyServer: op.NewLegacyServer(provider, endpoints(config.CustomEndpoints)),
hashAlg: crypto.NewBCrypt(10), // as we are only verifying in oidc, the cost is already part of the hash string and the config here is irrelevant.
}
metricTypes := []metrics.MetricType{metrics.MetricTypeRequestCount, metrics.MetricTypeStatusCode, metrics.MetricTypeTotalCount}
server.Handler = op.RegisterLegacyServer(server, op.WithHTTPMiddleware(

View File

@@ -6,6 +6,7 @@ import (
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
@@ -13,6 +14,8 @@ type Server struct {
http.Handler
storage *OPStorage
*op.LegacyServer
hashAlg crypto.HashAlgorithm
}
func endpoints(endpointConfig *EndpointConfig) op.Endpoints {

View File

@@ -0,0 +1,24 @@
with config as (
select app_id, client_id, client_secret
from projections.apps5_api_configs
where instance_id = $1
and client_id = $2
union
select app_id, client_id, client_secret
from projections.apps5_oidc_configs
where instance_id = $1
and client_id = $2
),
keys as (
select identifier as client_id, json_object_agg(id, public_key) as public_keys
from projections.authn_keys2
where $3 = true -- when argument is false, don't waste time on trying to query for keys.
and instance_id = $1
and identifier = $2
and expiration > current_timestamp
group by identifier
)
select apps.project_id, config.client_secret, keys.public_keys from config
join projections.apps5 apps on apps.id = config.app_id
left join keys on keys.client_id = config.client_id
where apps.owner_removed = false;

View File

@@ -0,0 +1,45 @@
package query
import (
"context"
"database/sql"
_ "embed"
"github.com/jackc/pgtype"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/database"
)
type IntrospectionClient struct {
ClientID string
ClientSecret *crypto.CryptoValue
ProjectID string
PublicKeys database.Map[[]byte]
}
//go:embed embed/introspection_client_by_id.sql
var introspectionClientByIDQuery string
func (q *Queries) GetIntrospectionClientByID(ctx context.Context, clientID string, getKeys bool) (_ *IntrospectionClient, err error) {
var (
instanceID = authz.GetInstance(ctx).InstanceID()
client = new(IntrospectionClient)
)
err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
var publicKeys pgtype.ByteaArray
if err := row.Scan(&client.ClientID, &client.ClientSecret, &client.ProjectID, &publicKeys); err != nil {
return err
}
return publicKeys.AssignTo(&client.PublicKeys)
},
introspectionClientByIDQuery,
instanceID, clientID, getKeys,
)
if err != nil {
return nil, err
}
return client, nil
}