perf(oidc): optimize the introspection endpoint (#6909)

* get key by id and cache them

* userinfo from events for v2 tokens

* improve keyset caching

* concurrent token and client checks

* client and project in single query

* logging and otel

* drop owner_removed column on apps and authN tables

* userinfo and project roles in go routines

* get  oidc user info from projections and add actions

* add avatar URL

* some cleanup

* pull oidc work branch

* remove storage from server

* add config flag for experimental introspection

* legacy introspection flag

* drop owner_removed column on user projections

* drop owner_removed column on useer_metadata

* query userinfo unit test

* query introspection client test

* add user_grants to the userinfo query

* handle PAT scopes

* bring triggers back

* test instance keys query

* add userinfo unit tests

* unit test keys

* go mod tidy

* solve some bugs

* fix missing preferred login name

* do not run triggers in go routines, they seem to deadlock

* initialize the trigger handlers late with a sync.OnceValue

* Revert "do not run triggers in go routines, they seem to deadlock"

This reverts commit 2a03da2127.

* add missing translations

* chore: update go version for linting

* pin oidc version

* parse a global time location for query test

* fix linter complains

* upgrade go lint

* fix more linting issues

---------

Co-authored-by: Stefan Benz <46600784+stebenz@users.noreply.github.com>
This commit is contained in:
Tim Möhlmann
2023-11-21 14:11:38 +02:00
committed by GitHub
parent ad3563d58b
commit ba9b807854
103 changed files with 3528 additions and 808 deletions

View File

@@ -0,0 +1,104 @@
package oidc
import (
"context"
"errors"
"strings"
"time"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op"
"github.com/zitadel/zitadel/internal/command"
errz "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/user/model"
)
type accessToken struct {
tokenID string
userID string
subject string
clientID string
audience []string
scope []string
tokenCreation time.Time
tokenExpiration time.Time
isPAT bool
}
func (s *Server) verifyAccessToken(ctx context.Context, tkn string) (*accessToken, error) {
var tokenID, subject string
if tokenIDSubject, err := s.Provider().Crypto().Decrypt(tkn); err == nil {
split := strings.Split(tokenIDSubject, ":")
if len(split) != 2 {
return nil, errors.New("invalid token format")
}
tokenID, subject = split[0], split[1]
} else {
verifier := op.NewAccessTokenVerifier(op.IssuerFromContext(ctx), s.keySet)
claims, err := op.VerifyAccessToken[*oidc.AccessTokenClaims](ctx, tkn, verifier)
if err != nil {
return nil, err
}
tokenID, subject = claims.JWTID, claims.Subject
}
if strings.HasPrefix(tokenID, command.IDPrefixV2) {
token, err := s.query.ActiveAccessTokenByToken(ctx, tokenID)
if err != nil {
return nil, err
}
return accessTokenV2(tokenID, subject, token), nil
}
token, err := s.repo.TokenByIDs(ctx, subject, tokenID)
if err != nil {
return nil, errz.ThrowPermissionDenied(err, "OIDC-Dsfb2", "token is not valid or has expired")
}
return accessTokenV1(tokenID, subject, token), nil
}
func accessTokenV1(tokenID, subject string, token *model.TokenView) *accessToken {
return &accessToken{
tokenID: tokenID,
userID: token.UserID,
subject: subject,
clientID: token.ApplicationID,
audience: token.Audience,
scope: token.Scopes,
tokenCreation: token.CreationDate,
tokenExpiration: token.Expiration,
isPAT: token.IsPAT,
}
}
func accessTokenV2(tokenID, subject string, token *query.OIDCSessionAccessTokenReadModel) *accessToken {
return &accessToken{
tokenID: tokenID,
userID: token.UserID,
subject: subject,
clientID: token.ClientID,
audience: token.Audience,
scope: token.Scope,
tokenCreation: token.AccessTokenCreation,
tokenExpiration: token.AccessTokenExpiration,
}
}
func (s *Server) assertClientScopesForPAT(ctx context.Context, token *accessToken, clientID, projectID string) error {
token.audience = append(token.audience, clientID)
projectIDQuery, err := query.NewProjectRoleProjectIDSearchQuery(projectID)
if err != nil {
return errz.ThrowInternal(err, "OIDC-Cyc78", "Errors.Internal")
}
roles, err := s.query.SearchProjectRoles(ctx, s.features.TriggerIntrospectionProjections, &query.ProjectRoleSearchQueries{Queries: []query.SearchQuery{projectIDQuery}})
if err != nil {
return err
}
for _, role := range roles.ProjectRoles {
token.scope = append(token.scope, ScopeProjectRolePrefix+role.Key)
}
return nil
}

View File

@@ -102,7 +102,7 @@ func (o *OPStorage) audienceFromProjectID(ctx context.Context, projectID string)
if err != nil {
return nil, err
}
appIDs, err := o.query.SearchClientIDs(ctx, &query.AppSearchQueries{Queries: []query.SearchQuery{projectIDQuery}}, false)
appIDs, err := o.query.SearchClientIDs(ctx, &query.AppSearchQueries{Queries: []query.SearchQuery{projectIDQuery}})
if err != nil {
return nil, err
}
@@ -432,7 +432,7 @@ func (o *OPStorage) assertProjectRoleScopes(ctx context.Context, clientID string
return scopes, nil
}
}
projectID, err := o.query.ProjectIDFromOIDCClientID(ctx, clientID, false)
projectID, err := o.query.ProjectIDFromOIDCClientID(ctx, clientID)
if err != nil {
return nil, errors.ThrowPreconditionFailed(nil, "OIDC-AEG4d", "Errors.Internal")
}

View File

@@ -43,7 +43,7 @@ const (
func (o *OPStorage) GetClientByClientID(ctx context.Context, id string) (_ op.Client, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
client, err := o.query.AppByOIDCClientID(ctx, id, false)
client, err := o.query.AppByOIDCClientID(ctx, id)
if err != nil {
return nil, err
}
@@ -94,7 +94,7 @@ func (o *OPStorage) GetKeyByIDAndIssuer(ctx context.Context, keyID, issuer strin
}
func (o *OPStorage) ValidateJWTProfileScopes(ctx context.Context, subject string, scopes []string) ([]string, error) {
user, err := o.query.GetUserByID(ctx, true, subject, false)
user, err := o.query.GetUserByID(ctx, true, subject)
if err != nil {
return nil, err
}
@@ -108,7 +108,7 @@ func (o *OPStorage) AuthorizeClientIDSecret(ctx context.Context, id string, secr
UserID: oidcCtx,
OrgID: oidcCtx,
})
app, err := o.query.AppByClientID(ctx, id, false)
app, err := o.query.AppByClientID(ctx, id)
if err != nil {
return err
}
@@ -149,7 +149,7 @@ func (o *OPStorage) SetUserinfoFromScopes(ctx context.Context, userInfo *oidc.Us
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
if applicationID != "" {
app, err := o.query.AppByOIDCClientID(ctx, applicationID, false)
app, err := o.query.AppByOIDCClientID(ctx, applicationID)
if err != nil {
return err
}
@@ -184,7 +184,7 @@ func (o *OPStorage) SetIntrospectionFromToken(ctx context.Context, introspection
if err != nil {
return err
}
projectID, err := o.query.ProjectIDFromClientID(ctx, clientID, false)
projectID, err := o.query.ProjectIDFromClientID(ctx, clientID)
if err != nil {
return errors.ThrowPermissionDenied(nil, "OIDC-Adfg5", "client not found")
}
@@ -198,7 +198,7 @@ func (o *OPStorage) SetIntrospectionFromToken(ctx context.Context, introspection
if err != nil {
return errors.ThrowPermissionDenied(nil, "OIDC-Dsfb2", "token is not valid or has expired")
}
projectID, err := o.query.ProjectIDFromClientID(ctx, clientID, false)
projectID, err := o.query.ProjectIDFromClientID(ctx, clientID)
if err != nil {
return errors.ThrowPermissionDenied(nil, "OIDC-Adfg5", "client not found")
}
@@ -219,7 +219,7 @@ func (o *OPStorage) ClientCredentialsTokenRequest(ctx context.Context, clientID
if err != nil {
return nil, err
}
user, err := o.query.GetUser(ctx, false, false, loginname)
user, err := o.query.GetUser(ctx, false, loginname)
if err != nil {
return nil, err
}
@@ -240,7 +240,7 @@ func (o *OPStorage) ClientCredentials(ctx context.Context, clientID, clientSecre
if err != nil {
return nil, err
}
user, err := o.query.GetUser(ctx, false, false, loginname)
user, err := o.query.GetUser(ctx, false, loginname)
if err != nil {
return nil, err
}
@@ -259,7 +259,7 @@ func (o *OPStorage) isOriginAllowed(ctx context.Context, clientID, origin string
if origin == "" {
return nil
}
app, err := o.query.AppByOIDCClientID(ctx, clientID, false)
app, err := o.query.AppByOIDCClientID(ctx, clientID)
if err != nil {
return err
}
@@ -331,7 +331,7 @@ func (o *OPStorage) checkOrgScopes(ctx context.Context, user *query.User, scopes
func (o *OPStorage) setUserinfo(ctx context.Context, userInfo *oidc.UserInfo, userID, applicationID string, scopes []string, roleAudience []string) (err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
user, err := o.query.GetUserByID(ctx, true, userID, false)
user, err := o.query.GetUserByID(ctx, true, userID)
if err != nil {
return err
}
@@ -645,7 +645,7 @@ func (o *OPStorage) GetPrivateClaimsFromScopes(ctx context.Context, userID, clie
}
func (o *OPStorage) privateClaimsFlows(ctx context.Context, userID string, userGrants *query.UserGrants, claims map[string]interface{}) (map[string]interface{}, error) {
user, err := o.query.GetUserByID(ctx, true, userID, false)
user, err := o.query.GetUserByID(ctx, true, userID)
if err != nil {
return nil, err
}
@@ -764,7 +764,7 @@ func (o *OPStorage) assertRoles(ctx context.Context, userID, applicationID strin
if (applicationID == "" || len(requestedRoles) == 0) && len(roleAudience) == 0 {
return nil, nil, nil
}
projectID, err := o.query.ProjectIDFromClientID(ctx, applicationID, false)
projectID, err := o.query.ProjectIDFromClientID(ctx, applicationID)
// applicationID might contain a username (e.g. client credentials) -> ignore the not found
if err != nil && !errors.IsNotFound(err) {
return nil, nil, err
@@ -795,7 +795,7 @@ func (o *OPStorage) assertRoles(ctx context.Context, userID, applicationID strin
if len(requestedRoles) > 0 {
for _, requestedRole := range requestedRoles {
for _, grant := range grants.UserGrants {
checkGrantedRoles(roles, grant, requestedRole, grant.ProjectID == projectID)
checkGrantedRoles(roles, *grant, requestedRole, grant.ProjectID == projectID)
}
}
return grants, roles, nil
@@ -823,7 +823,7 @@ func (o *OPStorage) assertUserMetaData(ctx context.Context, userID string) (map[
}
func (o *OPStorage) assertUserResourceOwner(ctx context.Context, userID string) (map[string]string, error) {
user, err := o.query.GetUserByID(ctx, true, userID, false)
user, err := o.query.GetUserByID(ctx, true, userID)
if err != nil {
return nil, err
}
@@ -838,7 +838,7 @@ func (o *OPStorage) assertUserResourceOwner(ctx context.Context, userID string)
}, nil
}
func checkGrantedRoles(roles *projectsRoles, grant *query.UserGrant, requestedRole string, isRequested bool) {
func checkGrantedRoles(roles *projectsRoles, grant query.UserGrant, requestedRole string, isRequested bool) {
for _, grantedRole := range grant.Roles {
if requestedRole == grantedRole {
roles.Add(grant.ProjectID, grantedRole, grant.ResourceOwner, grant.OrgPrimaryDomain, isRequested)
@@ -854,6 +854,26 @@ type projectsRoles struct {
requestProjectID string
}
func newProjectRoles(projectID string, grants []query.UserGrant, requestedRoles []string) *projectsRoles {
roles := new(projectsRoles)
// if specific roles where requested, check if they are granted and append them in the roles list
if len(requestedRoles) > 0 {
for _, requestedRole := range requestedRoles {
for _, grant := range grants {
checkGrantedRoles(roles, grant, requestedRole, grant.ProjectID == projectID)
}
}
return roles
}
// no specific roles were requested, so convert any grants into roles
for _, grant := range grants {
for _, role := range grant.Roles {
roles.Add(grant.ProjectID, role, grant.ResourceOwner, grant.OrgPrimaryDomain, grant.ProjectID == projectID)
}
}
return roles
}
func (p *projectsRoles) Add(projectID, roleKey, orgID, domain string, isRequested bool) {
if p.projects == nil {
p.projects = make(map[string]projectRoles, 1)

View File

@@ -48,7 +48,7 @@ func TestOPStorage_SetUserinfoFromToken(t *testing.T) {
assertUserinfo(t, userinfo)
}
func TestOPStorage_SetIntrospectionFromToken(t *testing.T) {
func TestServer_Introspect(t *testing.T) {
project, err := Tester.CreateProject(CTX)
require.NoError(t, err)
app, err := Tester.CreateOIDCNativeClient(CTX, redirectURI, logoutRedirectURI, project.GetId())

View File

@@ -0,0 +1,200 @@
package oidc
import (
"context"
"database/sql"
"errors"
"slices"
"time"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op"
"github.com/zitadel/zitadel/internal/crypto"
errz "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
func (s *Server) Introspect(ctx context.Context, r *op.Request[op.IntrospectionRequest]) (resp *op.Response, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
if s.features.LegacyIntrospection {
return s.LegacyServer.Introspect(ctx, r)
}
if s.features.TriggerIntrospectionProjections {
// Execute all triggers in one concurrent sweep.
query.TriggerIntrospectionProjections(ctx)
}
ctx, cancel := context.WithCancel(ctx)
defer cancel()
clientChan := make(chan *instrospectionClientResult)
go s.instrospectionClientAuth(ctx, r.Data.ClientCredentials, clientChan)
tokenChan := make(chan *introspectionTokenResult)
go s.introspectionToken(ctx, r.Data.Token, tokenChan)
var (
client *instrospectionClientResult
token *introspectionTokenResult
)
// make sure both channels are always read,
// and cancel the context on first error
for i := 0; i < 2; i++ {
var resErr error
select {
case client = <-clientChan:
resErr = client.err
case token = <-tokenChan:
resErr = token.err
}
if resErr == nil {
continue
}
cancel()
// we only care for the first error that occurred,
// as the next error is most probably a context error.
if err == nil {
err = resErr
}
}
// only client auth errors should be returned
var target *oidc.Error
if errors.As(err, &target) && target.ErrorType == oidc.UnauthorizedClient {
return nil, err
}
// remaining errors shoudn't be returned to the client,
// so we catch errors here, log them and return the response
// with active: false
defer func() {
if err != nil {
s.getLogger(ctx).ErrorContext(ctx, "oidc introspection", "err", err)
resp, err = op.NewResponse(new(oidc.IntrospectionResponse)), nil
}
}()
if err != nil {
return nil, err
}
// TODO: can we get rid of this separate query?
if token.isPAT {
if err = s.assertClientScopesForPAT(ctx, token.accessToken, client.clientID, client.projectID); err != nil {
return nil, err
}
}
if err = validateIntrospectionAudience(token.audience, client.clientID, client.projectID); err != nil {
return nil, err
}
userInfo, err := s.userInfo(ctx, token.userID, client.projectID, token.scope, []string{client.projectID})
if err != nil {
return nil, err
}
introspectionResp := &oidc.IntrospectionResponse{
Active: true,
Scope: token.scope,
ClientID: token.clientID,
TokenType: oidc.BearerToken,
Expiration: oidc.FromTime(token.tokenExpiration),
IssuedAt: oidc.FromTime(token.tokenCreation),
NotBefore: oidc.FromTime(token.tokenCreation),
Audience: token.audience,
Issuer: op.IssuerFromContext(ctx),
JWTID: token.tokenID,
}
introspectionResp.SetUserInfo(userInfo)
return op.NewResponse(introspectionResp), nil
}
type instrospectionClientResult struct {
clientID string
projectID string
err error
}
func (s *Server) instrospectionClientAuth(ctx context.Context, cc *op.ClientCredentials, rc chan<- *instrospectionClientResult) {
ctx, span := tracing.NewSpan(ctx)
clientID, projectID, err := func() (string, string, error) {
client, err := s.clientFromCredentials(ctx, cc)
if err != nil {
return "", "", err
}
if cc.ClientAssertion != "" {
verifier := op.NewJWTProfileVerifierKeySet(keySetMap(client.PublicKeys), op.IssuerFromContext(ctx), time.Hour, time.Second)
if _, err := op.VerifyJWTAssertion(ctx, cc.ClientAssertion, verifier); err != nil {
return "", "", oidc.ErrUnauthorizedClient().WithParent(err)
}
} else {
if err := crypto.CompareHash(client.ClientSecret, []byte(cc.ClientSecret), s.hashAlg); err != nil {
return "", "", oidc.ErrUnauthorizedClient().WithParent(err)
}
}
return client.ClientID, client.ProjectID, nil
}()
span.EndWithError(err)
rc <- &instrospectionClientResult{
clientID: clientID,
projectID: projectID,
err: err,
}
}
// clientFromCredentials parses the client ID early,
// and makes a single query for the client for either auth methods.
func (s *Server) clientFromCredentials(ctx context.Context, cc *op.ClientCredentials) (client *query.IntrospectionClient, err error) {
if cc.ClientAssertion != "" {
claims := new(oidc.JWTTokenRequest)
if _, err := oidc.ParseToken(cc.ClientAssertion, claims); err != nil {
return nil, oidc.ErrUnauthorizedClient().WithParent(err)
}
client, err = s.query.GetIntrospectionClientByID(ctx, claims.Issuer, true)
} else {
client, err = s.query.GetIntrospectionClientByID(ctx, cc.ClientID, false)
}
if errors.Is(err, sql.ErrNoRows) {
return nil, oidc.ErrUnauthorizedClient().WithParent(err)
}
// any other error is regarded internal and should not be reported back to the client.
return client, err
}
type introspectionTokenResult struct {
*accessToken
err error
}
func (s *Server) introspectionToken(ctx context.Context, tkn string, rc chan<- *introspectionTokenResult) {
ctx, span := tracing.NewSpan(ctx)
token, err := s.verifyAccessToken(ctx, tkn)
span.EndWithError(err)
rc <- &introspectionTokenResult{
accessToken: token,
err: err,
}
}
func validateIntrospectionAudience(audience []string, clientID, projectID string) error {
if slices.ContainsFunc(audience, func(entry string) bool {
return entry == clientID || entry == projectID
}) {
return nil
}
return errz.ThrowPermissionDenied(nil, "OIDC-sdg3G", "token is not valid for this client")
}

View File

@@ -12,7 +12,7 @@ import (
func (o *OPStorage) JWTProfileTokenType(ctx context.Context, request op.TokenRequest) (op.AccessTokenType, error) {
mapJWTProfileScopesToAudience(ctx, request)
user, err := o.query.GetUserByID(ctx, false, request.GetSubject(), false)
user, err := o.query.GetUserByID(ctx, false, request.GetSubject())
if err != nil {
return 0, err
}

View File

@@ -3,14 +3,17 @@ package oidc
import (
"context"
"fmt"
"sync"
"time"
"github.com/go-jose/go-jose/v3"
"github.com/jonboulle/clockwork"
"github.com/zitadel/logging"
"github.com/zitadel/oidc/v3/pkg/op"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/query"
@@ -19,6 +22,145 @@ import (
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
// keySetCache implements oidc.KeySet for Access Token verification.
// Public Keys are cached in a 2-dimensional map of Instance ID and Key ID.
// When a key is not present the queryKey function is called to obtain the key
// from the database.
type keySetCache struct {
mtx sync.RWMutex
instanceKeys map[string]map[string]query.PublicKey
queryKey func(ctx context.Context, keyID string, current time.Time) (query.PublicKey, error)
clock clockwork.Clock
}
// newKeySet initializes a keySetCache and starts a purging Go routine,
// which runs once every purgeInterval.
// When the passed context is done, the purge routine will terminate.
func newKeySet(background context.Context, purgeInterval time.Duration, queryKey func(ctx context.Context, keyID string, current time.Time) (query.PublicKey, error)) *keySetCache {
k := &keySetCache{
instanceKeys: make(map[string]map[string]query.PublicKey),
queryKey: queryKey,
clock: clockwork.FromContext(background), // defaults to real clock
}
go k.purgeOnInterval(background, k.clock.NewTicker(purgeInterval))
return k
}
func (k *keySetCache) purgeOnInterval(background context.Context, ticker clockwork.Ticker) {
defer ticker.Stop()
for {
select {
case <-background.Done():
return
case <-ticker.Chan():
}
// do the actual purging
k.mtx.Lock()
for instanceID, keys := range k.instanceKeys {
for keyID, key := range keys {
if key.Expiry().Before(k.clock.Now()) {
delete(keys, keyID)
}
}
if len(keys) == 0 {
delete(k.instanceKeys, instanceID)
}
}
k.mtx.Unlock()
}
}
func (k *keySetCache) setKey(instanceID, keyID string, key query.PublicKey) {
k.mtx.Lock()
defer k.mtx.Unlock()
if keys, ok := k.instanceKeys[instanceID]; ok {
keys[keyID] = key
return
}
k.instanceKeys[instanceID] = map[string]query.PublicKey{keyID: key}
}
func (k *keySetCache) getKey(ctx context.Context, keyID string) (_ *jose.JSONWebKey, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
instanceID := authz.GetInstance(ctx).InstanceID()
k.mtx.RLock()
key, ok := k.instanceKeys[instanceID][keyID]
k.mtx.RUnlock()
if ok {
if key.Expiry().After(k.clock.Now()) {
return jsonWebkey(key), nil
}
return nil, errors.ThrowInvalidArgument(nil, "OIDC-Zoh9E", "Errors.Key.ExpireBeforeNow")
}
key, err = k.queryKey(ctx, keyID, k.clock.Now())
if err != nil {
return nil, err
}
k.setKey(instanceID, keyID, key)
return jsonWebkey(key), nil
}
// VerifySignature implements the oidc.KeySet interface.
func (k *keySetCache) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (_ []byte, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
if len(jws.Signatures) != 1 {
return nil, errors.ThrowInvalidArgument(nil, "OIDC-Gid9s", "Errors.Token.Invalid")
}
key, err := k.getKey(ctx, jws.Signatures[0].Header.KeyID)
if err != nil {
return nil, err
}
return jws.Verify(key)
}
func jsonWebkey(key query.PublicKey) *jose.JSONWebKey {
return &jose.JSONWebKey{
KeyID: key.ID(),
Algorithm: key.Algorithm(),
Use: key.Use().String(),
Key: key.Key(),
}
}
// keySetMap is a mapping of key IDs to public key data.
type keySetMap map[string][]byte
// getKey finds the keyID and parses the public key data
// into a JSONWebKey.
func (k keySetMap) getKey(keyID string) (*jose.JSONWebKey, error) {
pubKey, err := crypto.BytesToPublicKey(k[keyID])
if err != nil {
return nil, err
}
return &jose.JSONWebKey{
Key: pubKey,
KeyID: keyID,
Use: domain.KeyUsageSigning.String(),
}, nil
}
// VerifySignature implements the oidc.KeySet interface.
func (k keySetMap) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
if len(jws.Signatures) != 1 {
return nil, errors.ThrowInvalidArgument(nil, "OIDC-Eeth6", "Errors.Token.Invalid")
}
key, err := k.getKey(jws.Signatures[0].Header.KeyID)
if err != nil {
return nil, err
}
return jws.Verify(key)
}
const (
locksTable = "projections.locks"
signingKey = "signing_key"

View File

@@ -0,0 +1,244 @@
package oidc
import (
"context"
"errors"
"testing"
"time"
"github.com/go-jose/go-jose/v3"
"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/query"
)
type publicKey struct {
id string
alg string
use domain.KeyUsage
seq uint64
expiry time.Time
key any
}
func (k *publicKey) ID() string {
return k.id
}
func (k *publicKey) Algorithm() string {
return k.alg
}
func (k *publicKey) Use() domain.KeyUsage {
return k.use
}
func (k *publicKey) Sequence() uint64 {
return k.seq
}
func (k *publicKey) Expiry() time.Time {
return k.expiry
}
func (k *publicKey) Key() any {
return k.key
}
var (
clock = clockwork.NewFakeClock()
keyDB = map[string]*publicKey{
"key1": {
id: "key1",
alg: "alg",
use: domain.KeyUsageSigning,
seq: 1,
expiry: clock.Now().Add(time.Minute),
},
"key2": {
id: "key2",
alg: "alg",
use: domain.KeyUsageSigning,
seq: 3,
expiry: clock.Now().Add(10 * time.Hour),
},
}
)
func queryKeyDB(_ context.Context, keyID string, current time.Time) (query.PublicKey, error) {
if key, ok := keyDB[keyID]; ok {
return key, nil
}
return nil, errors.New("not found")
}
func Test_keySetCache(t *testing.T) {
background, cancel := context.WithCancel(
clockwork.AddToContext(context.Background(), clock),
)
defer cancel()
// create an empty keySet with a purge go routine, runs every Hour
keySet := newKeySet(background, time.Hour, queryKeyDB)
ctx := authz.NewMockContext("instanceID", "orgID", "userID")
// query error
_, err := keySet.getKey(ctx, "key9")
require.Error(t, err)
want := &jose.JSONWebKey{
KeyID: "key1",
Algorithm: "alg",
Use: domain.KeyUsageSigning.String(),
}
// get key first time, populate the cache
got, err := keySet.getKey(ctx, "key1")
require.NoError(t, err)
assert.Equal(t, want, got)
// move time forward
clock.Advance(5 * time.Minute)
time.Sleep(time.Millisecond)
// key should still be in cache
keySet.mtx.RLock()
_, ok := keySet.instanceKeys["instanceID"]["key1"]
require.True(t, ok)
keySet.mtx.RUnlock()
// the key is expired, should error
_, err = keySet.getKey(ctx, "key1")
require.Error(t, err)
want = &jose.JSONWebKey{
KeyID: "key2",
Algorithm: "alg",
Use: domain.KeyUsageSigning.String(),
}
// get the second key from DB
got, err = keySet.getKey(ctx, "key2")
require.NoError(t, err)
assert.Equal(t, want, got)
// move time forward
clock.Advance(time.Hour)
time.Sleep(time.Millisecond)
// first key shoud be purged, second still present
keySet.mtx.RLock()
_, ok = keySet.instanceKeys["instanceID"]["key1"]
require.False(t, ok)
_, ok = keySet.instanceKeys["instanceID"]["key2"]
require.True(t, ok)
keySet.mtx.RUnlock()
// get the second key from cache
got, err = keySet.getKey(ctx, "key2")
require.NoError(t, err)
assert.Equal(t, want, got)
// move time forward
clock.Advance(10 * time.Hour)
time.Sleep(time.Millisecond)
// now the cache should be empty
keySet.mtx.RLock()
assert.Empty(t, keySet.instanceKeys)
keySet.mtx.RUnlock()
}
func Test_keySetCache_VerifySignature(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
k := newKeySet(ctx, time.Second, queryKeyDB)
tests := []struct {
name string
jws *jose.JSONWebSignature
}{
{
name: "invalid token",
jws: &jose.JSONWebSignature{},
},
{
name: "key not found",
jws: &jose.JSONWebSignature{
Signatures: []jose.Signature{{
Header: jose.Header{
KeyID: "xxx",
},
}},
},
},
{
name: "verify error",
jws: &jose.JSONWebSignature{
Signatures: []jose.Signature{{
Header: jose.Header{
KeyID: "key1",
},
}},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := k.VerifySignature(ctx, tt.jws)
require.Error(t, err)
})
}
}
func Test_keySetMap_VerifySignature(t *testing.T) {
tests := []struct {
name string
k keySetMap
jws *jose.JSONWebSignature
}{
{
name: "invalid signature",
k: keySetMap{
"key1": []byte("foo"),
},
jws: &jose.JSONWebSignature{},
},
{
name: "parse error",
k: keySetMap{
"key1": []byte("foo"),
},
jws: &jose.JSONWebSignature{
Signatures: []jose.Signature{{
Header: jose.Header{
KeyID: "key1",
},
}},
},
},
{
name: "verify error",
k: keySetMap{
"key1": []byte("-----BEGIN RSA PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAsvX9P58JFxEs5C+L+H7W\nduFSWL5EPzber7C2m94klrSV6q0bAcrYQnGwFOlveThsY200hRbadKaKjHD7qIKH\nDEe0IY2PSRht33Jye52AwhkRw+M3xuQH/7R8LydnsNFk2KHpr5X2SBv42e37LjkE\nslKSaMRgJW+v0KZ30piY8QsdFRKKaVg5/Ajt1YToM1YVsdHXJ3vmXFMtypLdxwUD\ndIaLEX6pFUkU75KSuEQ/E2luT61Q3ta9kOWm9+0zvi7OMcbdekJT7mzcVnh93R1c\n13ZhQCLbh9A7si8jKFtaMWevjayrvqQABEcTN9N4Hoxcyg6l4neZtRDk75OMYcqm\nDQIDAQAB\n-----END RSA PUBLIC KEY-----\n"),
},
jws: &jose.JSONWebSignature{
Signatures: []jose.Signature{{
Header: jose.Header{
KeyID: "key1",
},
}},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := tt.k.VerifySignature(context.Background(), tt.jws)
require.Error(t, err)
})
}
}

View File

@@ -45,6 +45,7 @@ type Config struct {
DeviceAuth *DeviceAuthorizationConfig
DefaultLoginURLV2 string
DefaultLogoutURLV2 string
Features Features
}
type EndpointConfig struct {
@@ -63,6 +64,11 @@ type Endpoint struct {
URL string
}
type Features struct {
TriggerIntrospectionProjections bool
LegacyIntrospection bool
}
type OPStorage struct {
repo repository.Repository
command *command.Commands
@@ -120,7 +126,15 @@ func NewServer(
server := &Server{
LegacyServer: op.NewLegacyServer(provider, endpoints(config.CustomEndpoints)),
features: config.Features,
repo: repo,
query: query,
command: command,
keySet: newKeySet(context.TODO(), time.Hour, query.GetActivePublicKeyByID),
fallbackLogger: fallbackLogger,
hashAlg: crypto.NewBCrypt(10), // as we are only verifying in oidc, the cost is already part of the hash string and the config here is irrelevant.
signingKeyAlgorithm: config.SigningKeyAlgorithm,
assetAPIPrefix: assets.AssetAPI(externalSecure),
}
metricTypes := []metrics.MetricType{metrics.MetricTypeRequestCount, metrics.MetricTypeStatusCode, metrics.MetricTypeTotalCount}
server.Handler = op.RegisterLegacyServer(server, op.WithHTTPMiddleware(

View File

@@ -4,16 +4,32 @@ import (
"context"
"net/http"
"github.com/zitadel/logging"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op"
"golang.org/x/exp/slog"
"github.com/zitadel/zitadel/internal/auth/repository"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
type Server struct {
http.Handler
*op.LegacyServer
features Features
repo repository.Repository
query *query.Queries
command *command.Commands
keySet *keySetCache
fallbackLogger *slog.Logger
hashAlg crypto.HashAlgorithm
signingKeyAlgorithm string
assetAPIPrefix func(ctx context.Context) string
}
func endpoints(endpointConfig *EndpointConfig) op.Endpoints {
@@ -59,6 +75,13 @@ func endpoints(endpointConfig *EndpointConfig) op.Endpoints {
return endpoints
}
func (s *Server) getLogger(ctx context.Context) *slog.Logger {
if logger, ok := logging.FromContext(ctx); ok {
return logger
}
return s.fallbackLogger
}
func (s *Server) IssuerFromRequest(r *http.Request) string {
return s.Provider().IssuerFromRequest(r)
}
@@ -161,13 +184,6 @@ func (s *Server) DeviceToken(ctx context.Context, r *op.ClientRequest[oidc.Devic
return s.LegacyServer.DeviceToken(ctx, r)
}
func (s *Server) Introspect(ctx context.Context, r *op.Request[op.IntrospectionRequest]) (_ *op.Response, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
return s.LegacyServer.Introspect(ctx, r)
}
func (s *Server) UserInfo(ctx context.Context, r *op.Request[oidc.UserInfoRequest]) (_ *op.Response, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()

View File

@@ -0,0 +1,276 @@
package oidc
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"slices"
"strings"
"github.com/dop251/goja"
"github.com/zitadel/logging"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/zitadel/internal/actions"
"github.com/zitadel/zitadel/internal/actions/object"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/query"
)
func (s *Server) userInfo(ctx context.Context, userID, projectID string, scope, roleAudience []string) (_ *oidc.UserInfo, err error) {
roleAudience, requestedRoles := prepareRoles(ctx, projectID, scope, roleAudience)
qu, err := s.query.GetOIDCUserInfo(ctx, userID, roleAudience)
if err != nil {
return nil, err
}
userInfo := userInfoToOIDC(projectID, qu, scope, roleAudience, requestedRoles, s.assetAPIPrefix(ctx))
return userInfo, s.userinfoFlows(ctx, qu, userInfo)
}
// prepareRoles scans the requested scopes, appends to roleAudiendce and returns the requestedRoles.
//
// When [ScopeProjectsRoles] is present and roleAudience was empty,
// project IDs with the [domain.ProjectIDScope] prefix are added to the roleAudience.
//
// Scopes with [ScopeProjectRolePrefix] are added to requestedRoles.
//
// If the resulting requestedRoles or roleAudience are not not empty,
// the current projectID will always be parts or roleAudience.
// Else nil, nil is returned.
func prepareRoles(ctx context.Context, projectID string, scope, roleAudience []string) (ra, requestedRoles []string) {
// if all roles are requested take the audience for those from the scopes
if slices.Contains(scope, ScopeProjectsRoles) && len(roleAudience) == 0 {
roleAudience = domain.AddAudScopeToAudience(ctx, roleAudience, scope)
}
requestedRoles = make([]string, 0, len(scope))
for _, s := range scope {
if role, ok := strings.CutPrefix(s, ScopeProjectRolePrefix); ok {
requestedRoles = append(requestedRoles, role)
}
}
if len(requestedRoles) == 0 && len(roleAudience) == 0 {
return nil, nil
}
if projectID != "" && !slices.Contains(roleAudience, projectID) {
roleAudience = append(roleAudience, projectID)
}
return roleAudience, requestedRoles
}
func userInfoToOIDC(projectID string, user *query.OIDCUserInfo, scope, roleAudience, requestedRoles []string, assetPrefix string) *oidc.UserInfo {
out := new(oidc.UserInfo)
for _, s := range scope {
switch s {
case oidc.ScopeOpenID:
out.Subject = user.User.ID
case oidc.ScopeEmail:
out.UserInfoEmail = userInfoEmailToOIDC(user.User)
case oidc.ScopeProfile:
out.UserInfoProfile = userInfoProfileToOidc(user.User, assetPrefix)
case oidc.ScopePhone:
out.UserInfoPhone = userInfoPhoneToOIDC(user.User)
case oidc.ScopeAddress:
//TODO: handle address for human users as soon as implemented
case ScopeUserMetaData:
setUserInfoMetadata(user.Metadata, out)
case ScopeResourceOwner:
setUserInfoOrgClaims(user, out)
default:
if claim, ok := strings.CutPrefix(s, domain.OrgDomainPrimaryScope); ok {
out.AppendClaims(domain.OrgDomainPrimaryClaim, claim)
}
if claim, ok := strings.CutPrefix(s, domain.OrgIDScope); ok {
out.AppendClaims(domain.OrgIDClaim, claim)
setUserInfoOrgClaims(user, out)
}
}
}
// prevent returning obtained grants if none where requested
if (projectID != "" && len(requestedRoles) > 0) || len(roleAudience) > 0 {
setUserInfoRoleClaims(out, newProjectRoles(projectID, user.UserGrants, requestedRoles))
}
return out
}
func userInfoEmailToOIDC(user *query.User) oidc.UserInfoEmail {
if human := user.Human; human != nil {
return oidc.UserInfoEmail{
Email: string(human.Email),
EmailVerified: oidc.Bool(human.IsEmailVerified),
}
}
return oidc.UserInfoEmail{}
}
func userInfoProfileToOidc(user *query.User, assetPrefix string) oidc.UserInfoProfile {
if human := user.Human; human != nil {
return oidc.UserInfoProfile{
Name: human.DisplayName,
GivenName: human.FirstName,
FamilyName: human.LastName,
Nickname: human.NickName,
Picture: domain.AvatarURL(assetPrefix, user.ResourceOwner, user.Human.AvatarKey),
Gender: getGender(human.Gender),
Locale: oidc.NewLocale(human.PreferredLanguage),
UpdatedAt: oidc.FromTime(user.ChangeDate),
PreferredUsername: user.PreferredLoginName,
}
}
if machine := user.Machine; machine != nil {
return oidc.UserInfoProfile{
Name: machine.Name,
UpdatedAt: oidc.FromTime(user.ChangeDate),
PreferredUsername: user.PreferredLoginName,
}
}
return oidc.UserInfoProfile{}
}
func userInfoPhoneToOIDC(user *query.User) oidc.UserInfoPhone {
if human := user.Human; human != nil {
return oidc.UserInfoPhone{
PhoneNumber: string(human.Phone),
PhoneNumberVerified: human.IsPhoneVerified,
}
}
return oidc.UserInfoPhone{}
}
func setUserInfoMetadata(metadata []query.UserMetadata, out *oidc.UserInfo) {
if len(metadata) == 0 {
return
}
mdmap := make(map[string]string, len(metadata))
for _, md := range metadata {
mdmap[md.Key] = base64.RawURLEncoding.EncodeToString(md.Value)
}
out.AppendClaims(ClaimUserMetaData, mdmap)
}
func setUserInfoOrgClaims(user *query.OIDCUserInfo, out *oidc.UserInfo) {
if org := user.Org; org != nil {
out.AppendClaims(ClaimResourceOwner+"id", org.ID)
out.AppendClaims(ClaimResourceOwner+"name", org.Name)
out.AppendClaims(ClaimResourceOwner+"primary_domain", org.PrimaryDomain)
}
}
func setUserInfoRoleClaims(userInfo *oidc.UserInfo, roles *projectsRoles) {
if roles != nil && len(roles.projects) > 0 {
if roles, ok := roles.projects[roles.requestProjectID]; ok {
userInfo.AppendClaims(ClaimProjectRoles, roles)
}
for projectID, roles := range roles.projects {
userInfo.AppendClaims(fmt.Sprintf(ClaimProjectRolesFormat, projectID), roles)
}
}
}
func (s *Server) userinfoFlows(ctx context.Context, qu *query.OIDCUserInfo, userInfo *oidc.UserInfo) error {
queriedActions, err := s.query.GetActiveActionsByFlowAndTriggerType(ctx, domain.FlowTypeCustomiseToken, domain.TriggerTypePreUserinfoCreation, qu.User.ResourceOwner)
if err != nil {
return err
}
ctxFields := actions.SetContextFields(
actions.SetFields("v1",
actions.SetFields("claims", userinfoClaims(userInfo)),
actions.SetFields("getUser", func(c *actions.FieldConfig) interface{} {
return func(call goja.FunctionCall) goja.Value {
return object.UserFromQuery(c, qu.User)
}
}),
actions.SetFields("user",
actions.SetFields("getMetadata", func(c *actions.FieldConfig) interface{} {
return func(goja.FunctionCall) goja.Value {
return object.UserMetadataListFromSlice(c, qu.Metadata)
}
}),
actions.SetFields("grants", func(c *actions.FieldConfig) interface{} {
return object.UserGrantsFromSlice(c, qu.UserGrants)
}),
),
),
)
for _, action := range queriedActions {
actionCtx, cancel := context.WithTimeout(ctx, action.Timeout())
claimLogs := []string{}
apiFields := actions.WithAPIFields(
actions.SetFields("v1",
actions.SetFields("userinfo",
actions.SetFields("setClaim", func(key string, value interface{}) {
if userInfo.Claims[key] == nil {
userInfo.AppendClaims(key, value)
return
}
claimLogs = append(claimLogs, fmt.Sprintf("key %q already exists", key))
}),
actions.SetFields("appendLogIntoClaims", func(entry string) {
claimLogs = append(claimLogs, entry)
}),
),
actions.SetFields("claims",
actions.SetFields("setClaim", func(key string, value interface{}) {
if userInfo.Claims[key] == nil {
userInfo.AppendClaims(key, value)
return
}
claimLogs = append(claimLogs, fmt.Sprintf("key %q already exists", key))
}),
actions.SetFields("appendLogIntoClaims", func(entry string) {
claimLogs = append(claimLogs, entry)
}),
),
actions.SetFields("user",
actions.SetFields("setMetadata", func(call goja.FunctionCall) goja.Value {
if len(call.Arguments) != 2 {
panic("exactly 2 (key, value) arguments expected")
}
key := call.Arguments[0].Export().(string)
val := call.Arguments[1].Export()
value, err := json.Marshal(val)
if err != nil {
logging.WithError(err).Debug("unable to marshal")
panic(err)
}
metadata := &domain.Metadata{
Key: key,
Value: value,
}
if _, err = s.command.SetUserMetadata(ctx, metadata, userInfo.Subject, qu.User.ResourceOwner); err != nil {
logging.WithError(err).Info("unable to set md in action")
panic(err)
}
return nil
}),
),
),
)
err = actions.Run(
actionCtx,
ctxFields,
apiFields,
action.Script,
action.Name,
append(actions.ActionToOptions(action), actions.WithHTTP(actionCtx), actions.WithUUID(actionCtx))...,
)
cancel()
if err != nil {
return err
}
if len(claimLogs) > 0 {
userInfo.AppendClaims(fmt.Sprintf(ClaimActionLogFormat, action.Name), claimLogs)
}
}
return nil
}

View File

@@ -0,0 +1,434 @@
package oidc
import (
"context"
"encoding/base64"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/zitadel/oidc/v3/pkg/oidc"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/query"
)
func Test_prepareRoles(t *testing.T) {
type args struct {
projectID string
scope []string
roleAudience []string
}
tests := []struct {
name string
args args
wantRa []string
wantRequestedRoles []string
}{
{
name: "empty scope and roleAudience",
args: args{
projectID: "projID",
scope: nil,
roleAudience: nil,
},
wantRa: nil,
wantRequestedRoles: nil,
},
{
name: "some scope and roleAudience",
args: args{
projectID: "projID",
scope: []string{"openid", "profile"},
roleAudience: []string{"project2"},
},
wantRa: []string{"project2", "projID"},
wantRequestedRoles: []string{},
},
{
name: "scope projects roles",
args: args{
projectID: "projID",
scope: []string{ScopeProjectsRoles, domain.ProjectIDScope + "project2" + domain.AudSuffix},
roleAudience: nil,
},
wantRa: []string{"project2", "projID"},
wantRequestedRoles: []string{},
},
{
name: "scope project role prefix",
args: args{
projectID: "projID",
scope: []string{"openid", "profile", ScopeProjectRolePrefix + "foo", ScopeProjectRolePrefix + "bar"},
roleAudience: nil,
},
wantRa: []string{"projID"},
wantRequestedRoles: []string{"foo", "bar"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotRa, gotRequestedRoles := prepareRoles(context.Background(), tt.args.projectID, tt.args.scope, tt.args.roleAudience)
assert.Equal(t, tt.wantRa, gotRa, "roleAudience")
assert.Equal(t, tt.wantRequestedRoles, gotRequestedRoles, "requestedRoles")
})
}
}
func Test_userInfoToOIDC(t *testing.T) {
metadata := []query.UserMetadata{
{
Key: "key1",
Value: []byte{1, 2, 3},
},
{
Key: "key2",
Value: []byte{4, 5, 6},
},
}
organization := &query.UserInfoOrg{
ID: "orgID",
Name: "orgName",
PrimaryDomain: "orgDomain",
}
humanUserInfo := &query.OIDCUserInfo{
User: &query.User{
ID: "human1",
CreationDate: time.Unix(123, 456),
ChangeDate: time.Unix(567, 890),
ResourceOwner: "orgID",
Sequence: 22,
State: domain.UserStateActive,
Type: domain.UserTypeHuman,
Username: "username",
LoginNames: []string{"foo", "bar"},
PreferredLoginName: "foo",
Human: &query.Human{
FirstName: "user",
LastName: "name",
NickName: "foobar",
DisplayName: "xxx",
AvatarKey: "picture.png",
PreferredLanguage: language.Dutch,
Gender: domain.GenderDiverse,
Email: "foo@bar.com",
IsEmailVerified: true,
Phone: "+31123456789",
IsPhoneVerified: true,
},
},
Metadata: metadata,
Org: organization,
UserGrants: []query.UserGrant{
{
ID: "ug1",
CreationDate: time.Unix(444, 444),
ChangeDate: time.Unix(555, 555),
Sequence: 55,
Roles: []string{"role1", "role2"},
GrantID: "grantID",
State: domain.UserGrantStateActive,
UserID: "human1",
Username: "username",
ResourceOwner: "orgID",
ProjectID: "project1",
OrgName: "orgName",
OrgPrimaryDomain: "orgDomain",
ProjectName: "projectName",
UserResourceOwner: "org1",
},
},
}
machineUserInfo := &query.OIDCUserInfo{
User: &query.User{
ID: "machine1",
CreationDate: time.Unix(123, 456),
ChangeDate: time.Unix(567, 890),
ResourceOwner: "orgID",
Sequence: 23,
State: domain.UserStateActive,
Type: domain.UserTypeMachine,
Username: "machine",
PreferredLoginName: "meanMachine",
Machine: &query.Machine{
Name: "machine",
Description: "I'm a robot",
},
},
Org: organization,
UserGrants: []query.UserGrant{
{
ID: "ug1",
CreationDate: time.Unix(444, 444),
ChangeDate: time.Unix(555, 555),
Sequence: 55,
Roles: []string{"role1", "role2"},
GrantID: "grantID",
State: domain.UserGrantStateActive,
UserID: "human1",
Username: "username",
ResourceOwner: "orgID",
ProjectID: "project1",
OrgName: "orgName",
OrgPrimaryDomain: "orgDomain",
ProjectName: "projectName",
UserResourceOwner: "org1",
},
},
}
type args struct {
projectID string
user *query.OIDCUserInfo
scope []string
roleAudience []string
requestedRoles []string
}
tests := []struct {
name string
args args
want *oidc.UserInfo
}{
{
name: "human, empty",
args: args{
projectID: "project1",
user: humanUserInfo,
},
want: &oidc.UserInfo{},
},
{
name: "machine, empty",
args: args{
projectID: "project1",
user: machineUserInfo,
},
want: &oidc.UserInfo{},
},
{
name: "human, scope openid",
args: args{
projectID: "project1",
user: humanUserInfo,
scope: []string{oidc.ScopeOpenID},
},
want: &oidc.UserInfo{
Subject: "human1",
},
},
{
name: "machine, scope openid",
args: args{
projectID: "project1",
user: machineUserInfo,
scope: []string{oidc.ScopeOpenID},
},
want: &oidc.UserInfo{
Subject: "machine1",
},
},
{
name: "human, scope email",
args: args{
projectID: "project1",
user: humanUserInfo,
scope: []string{oidc.ScopeEmail},
},
want: &oidc.UserInfo{
UserInfoEmail: oidc.UserInfoEmail{
Email: "foo@bar.com",
EmailVerified: true,
},
},
},
{
name: "machine, scope email",
args: args{
projectID: "project1",
user: machineUserInfo,
scope: []string{oidc.ScopeEmail},
},
want: &oidc.UserInfo{
UserInfoEmail: oidc.UserInfoEmail{},
},
},
{
name: "human, scope profile",
args: args{
projectID: "project1",
user: humanUserInfo,
scope: []string{oidc.ScopeProfile},
},
want: &oidc.UserInfo{
UserInfoProfile: oidc.UserInfoProfile{
Name: "xxx",
GivenName: "user",
FamilyName: "name",
Nickname: "foobar",
Picture: "https://foo.com/assets/orgID/picture.png",
Gender: "diverse",
Locale: oidc.NewLocale(language.Dutch),
UpdatedAt: oidc.FromTime(time.Unix(567, 890)),
PreferredUsername: "foo",
},
},
},
{
name: "machine, scope profile",
args: args{
projectID: "project1",
user: machineUserInfo,
scope: []string{oidc.ScopeProfile},
},
want: &oidc.UserInfo{
UserInfoProfile: oidc.UserInfoProfile{
Name: "machine",
UpdatedAt: oidc.FromTime(time.Unix(567, 890)),
PreferredUsername: "meanMachine",
},
},
},
{
name: "human, scope phone",
args: args{
projectID: "project1",
user: humanUserInfo,
scope: []string{oidc.ScopePhone},
},
want: &oidc.UserInfo{
UserInfoPhone: oidc.UserInfoPhone{
PhoneNumber: "+31123456789",
PhoneNumberVerified: true,
},
},
},
{
name: "machine, scope phone",
args: args{
projectID: "project1",
user: machineUserInfo,
scope: []string{oidc.ScopePhone},
},
want: &oidc.UserInfo{
UserInfoPhone: oidc.UserInfoPhone{},
},
},
{
name: "human, scope metadata",
args: args{
projectID: "project1",
user: humanUserInfo,
scope: []string{ScopeUserMetaData},
},
want: &oidc.UserInfo{
Claims: map[string]any{
ClaimUserMetaData: map[string]string{
"key1": base64.RawURLEncoding.EncodeToString([]byte{1, 2, 3}),
"key2": base64.RawURLEncoding.EncodeToString([]byte{4, 5, 6}),
},
},
},
},
{
name: "machine, scope metadata, none found",
args: args{
projectID: "project1",
user: machineUserInfo,
scope: []string{ScopeUserMetaData},
},
want: &oidc.UserInfo{},
},
{
name: "machine, scope resource owner",
args: args{
projectID: "project1",
user: machineUserInfo,
scope: []string{ScopeResourceOwner},
},
want: &oidc.UserInfo{
Claims: map[string]any{
ClaimResourceOwner + "id": "orgID",
ClaimResourceOwner + "name": "orgName",
ClaimResourceOwner + "primary_domain": "orgDomain",
},
},
},
{
name: "human, scope org primary domain prefix",
args: args{
projectID: "project1",
user: humanUserInfo,
scope: []string{domain.OrgDomainPrimaryScope + "foo.com"},
},
want: &oidc.UserInfo{
Claims: map[string]any{
domain.OrgDomainPrimaryClaim: "foo.com",
},
},
},
{
name: "machine, scope org id",
args: args{
projectID: "project1",
user: machineUserInfo,
scope: []string{domain.OrgIDScope + "orgID"},
},
want: &oidc.UserInfo{
Claims: map[string]any{
domain.OrgIDClaim: "orgID",
ClaimResourceOwner + "id": "orgID",
ClaimResourceOwner + "name": "orgName",
ClaimResourceOwner + "primary_domain": "orgDomain",
},
},
},
{
name: "human, roleAudience",
args: args{
projectID: "project1",
user: humanUserInfo,
roleAudience: []string{"project1"},
},
want: &oidc.UserInfo{
Claims: map[string]any{
ClaimProjectRoles: projectRoles{
"role1": {"orgID": "orgDomain"},
"role2": {"orgID": "orgDomain"},
},
fmt.Sprintf(ClaimProjectRolesFormat, "project1"): projectRoles{
"role1": {"orgID": "orgDomain"},
"role2": {"orgID": "orgDomain"},
},
},
},
},
{
name: "human, requested roles",
args: args{
projectID: "project1",
user: humanUserInfo,
roleAudience: []string{"project1"},
requestedRoles: []string{"role2"},
},
want: &oidc.UserInfo{
Claims: map[string]any{
ClaimProjectRoles: projectRoles{
"role2": {"orgID": "orgDomain"},
},
fmt.Sprintf(ClaimProjectRolesFormat, "project1"): projectRoles{
"role2": {"orgID": "orgDomain"},
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assetPrefix := "https://foo.com/assets"
got := userInfoToOIDC(tt.args.projectID, tt.args.user, tt.args.scope, tt.args.roleAudience, tt.args.requestedRoles, assetPrefix)
assert.Equal(t, tt.want, got)
})
}
}