mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-12 01:37:31 +00:00
perf(oidc): optimize the introspection endpoint (#6909)
* get key by id and cache them
* userinfo from events for v2 tokens
* improve keyset caching
* concurrent token and client checks
* client and project in single query
* logging and otel
* drop owner_removed column on apps and authN tables
* userinfo and project roles in go routines
* get oidc user info from projections and add actions
* add avatar URL
* some cleanup
* pull oidc work branch
* remove storage from server
* add config flag for experimental introspection
* legacy introspection flag
* drop owner_removed column on user projections
* drop owner_removed column on useer_metadata
* query userinfo unit test
* query introspection client test
* add user_grants to the userinfo query
* handle PAT scopes
* bring triggers back
* test instance keys query
* add userinfo unit tests
* unit test keys
* go mod tidy
* solve some bugs
* fix missing preferred login name
* do not run triggers in go routines, they seem to deadlock
* initialize the trigger handlers late with a sync.OnceValue
* Revert "do not run triggers in go routines, they seem to deadlock"
This reverts commit 2a03da2127
.
* add missing translations
* chore: update go version for linting
* pin oidc version
* parse a global time location for query test
* fix linter complains
* upgrade go lint
* fix more linting issues
---------
Co-authored-by: Stefan Benz <46600784+stebenz@users.noreply.github.com>
This commit is contained in:
104
internal/api/oidc/access_token.go
Normal file
104
internal/api/oidc/access_token.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/command"
|
||||
errz "github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
"github.com/zitadel/zitadel/internal/user/model"
|
||||
)
|
||||
|
||||
type accessToken struct {
|
||||
tokenID string
|
||||
userID string
|
||||
subject string
|
||||
clientID string
|
||||
audience []string
|
||||
scope []string
|
||||
tokenCreation time.Time
|
||||
tokenExpiration time.Time
|
||||
isPAT bool
|
||||
}
|
||||
|
||||
func (s *Server) verifyAccessToken(ctx context.Context, tkn string) (*accessToken, error) {
|
||||
var tokenID, subject string
|
||||
|
||||
if tokenIDSubject, err := s.Provider().Crypto().Decrypt(tkn); err == nil {
|
||||
split := strings.Split(tokenIDSubject, ":")
|
||||
if len(split) != 2 {
|
||||
return nil, errors.New("invalid token format")
|
||||
}
|
||||
tokenID, subject = split[0], split[1]
|
||||
} else {
|
||||
verifier := op.NewAccessTokenVerifier(op.IssuerFromContext(ctx), s.keySet)
|
||||
claims, err := op.VerifyAccessToken[*oidc.AccessTokenClaims](ctx, tkn, verifier)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tokenID, subject = claims.JWTID, claims.Subject
|
||||
}
|
||||
|
||||
if strings.HasPrefix(tokenID, command.IDPrefixV2) {
|
||||
token, err := s.query.ActiveAccessTokenByToken(ctx, tokenID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return accessTokenV2(tokenID, subject, token), nil
|
||||
}
|
||||
|
||||
token, err := s.repo.TokenByIDs(ctx, subject, tokenID)
|
||||
if err != nil {
|
||||
return nil, errz.ThrowPermissionDenied(err, "OIDC-Dsfb2", "token is not valid or has expired")
|
||||
}
|
||||
return accessTokenV1(tokenID, subject, token), nil
|
||||
}
|
||||
|
||||
func accessTokenV1(tokenID, subject string, token *model.TokenView) *accessToken {
|
||||
return &accessToken{
|
||||
tokenID: tokenID,
|
||||
userID: token.UserID,
|
||||
subject: subject,
|
||||
clientID: token.ApplicationID,
|
||||
audience: token.Audience,
|
||||
scope: token.Scopes,
|
||||
tokenCreation: token.CreationDate,
|
||||
tokenExpiration: token.Expiration,
|
||||
isPAT: token.IsPAT,
|
||||
}
|
||||
}
|
||||
|
||||
func accessTokenV2(tokenID, subject string, token *query.OIDCSessionAccessTokenReadModel) *accessToken {
|
||||
return &accessToken{
|
||||
tokenID: tokenID,
|
||||
userID: token.UserID,
|
||||
subject: subject,
|
||||
clientID: token.ClientID,
|
||||
audience: token.Audience,
|
||||
scope: token.Scope,
|
||||
tokenCreation: token.AccessTokenCreation,
|
||||
tokenExpiration: token.AccessTokenExpiration,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) assertClientScopesForPAT(ctx context.Context, token *accessToken, clientID, projectID string) error {
|
||||
token.audience = append(token.audience, clientID)
|
||||
projectIDQuery, err := query.NewProjectRoleProjectIDSearchQuery(projectID)
|
||||
if err != nil {
|
||||
return errz.ThrowInternal(err, "OIDC-Cyc78", "Errors.Internal")
|
||||
}
|
||||
roles, err := s.query.SearchProjectRoles(ctx, s.features.TriggerIntrospectionProjections, &query.ProjectRoleSearchQueries{Queries: []query.SearchQuery{projectIDQuery}})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, role := range roles.ProjectRoles {
|
||||
token.scope = append(token.scope, ScopeProjectRolePrefix+role.Key)
|
||||
}
|
||||
return nil
|
||||
}
|
@@ -102,7 +102,7 @@ func (o *OPStorage) audienceFromProjectID(ctx context.Context, projectID string)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
appIDs, err := o.query.SearchClientIDs(ctx, &query.AppSearchQueries{Queries: []query.SearchQuery{projectIDQuery}}, false)
|
||||
appIDs, err := o.query.SearchClientIDs(ctx, &query.AppSearchQueries{Queries: []query.SearchQuery{projectIDQuery}})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -432,7 +432,7 @@ func (o *OPStorage) assertProjectRoleScopes(ctx context.Context, clientID string
|
||||
return scopes, nil
|
||||
}
|
||||
}
|
||||
projectID, err := o.query.ProjectIDFromOIDCClientID(ctx, clientID, false)
|
||||
projectID, err := o.query.ProjectIDFromOIDCClientID(ctx, clientID)
|
||||
if err != nil {
|
||||
return nil, errors.ThrowPreconditionFailed(nil, "OIDC-AEG4d", "Errors.Internal")
|
||||
}
|
||||
|
@@ -43,7 +43,7 @@ const (
|
||||
func (o *OPStorage) GetClientByClientID(ctx context.Context, id string) (_ op.Client, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
client, err := o.query.AppByOIDCClientID(ctx, id, false)
|
||||
client, err := o.query.AppByOIDCClientID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -94,7 +94,7 @@ func (o *OPStorage) GetKeyByIDAndIssuer(ctx context.Context, keyID, issuer strin
|
||||
}
|
||||
|
||||
func (o *OPStorage) ValidateJWTProfileScopes(ctx context.Context, subject string, scopes []string) ([]string, error) {
|
||||
user, err := o.query.GetUserByID(ctx, true, subject, false)
|
||||
user, err := o.query.GetUserByID(ctx, true, subject)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -108,7 +108,7 @@ func (o *OPStorage) AuthorizeClientIDSecret(ctx context.Context, id string, secr
|
||||
UserID: oidcCtx,
|
||||
OrgID: oidcCtx,
|
||||
})
|
||||
app, err := o.query.AppByClientID(ctx, id, false)
|
||||
app, err := o.query.AppByClientID(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -149,7 +149,7 @@ func (o *OPStorage) SetUserinfoFromScopes(ctx context.Context, userInfo *oidc.Us
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
if applicationID != "" {
|
||||
app, err := o.query.AppByOIDCClientID(ctx, applicationID, false)
|
||||
app, err := o.query.AppByOIDCClientID(ctx, applicationID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -184,7 +184,7 @@ func (o *OPStorage) SetIntrospectionFromToken(ctx context.Context, introspection
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
projectID, err := o.query.ProjectIDFromClientID(ctx, clientID, false)
|
||||
projectID, err := o.query.ProjectIDFromClientID(ctx, clientID)
|
||||
if err != nil {
|
||||
return errors.ThrowPermissionDenied(nil, "OIDC-Adfg5", "client not found")
|
||||
}
|
||||
@@ -198,7 +198,7 @@ func (o *OPStorage) SetIntrospectionFromToken(ctx context.Context, introspection
|
||||
if err != nil {
|
||||
return errors.ThrowPermissionDenied(nil, "OIDC-Dsfb2", "token is not valid or has expired")
|
||||
}
|
||||
projectID, err := o.query.ProjectIDFromClientID(ctx, clientID, false)
|
||||
projectID, err := o.query.ProjectIDFromClientID(ctx, clientID)
|
||||
if err != nil {
|
||||
return errors.ThrowPermissionDenied(nil, "OIDC-Adfg5", "client not found")
|
||||
}
|
||||
@@ -219,7 +219,7 @@ func (o *OPStorage) ClientCredentialsTokenRequest(ctx context.Context, clientID
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
user, err := o.query.GetUser(ctx, false, false, loginname)
|
||||
user, err := o.query.GetUser(ctx, false, loginname)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -240,7 +240,7 @@ func (o *OPStorage) ClientCredentials(ctx context.Context, clientID, clientSecre
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
user, err := o.query.GetUser(ctx, false, false, loginname)
|
||||
user, err := o.query.GetUser(ctx, false, loginname)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -259,7 +259,7 @@ func (o *OPStorage) isOriginAllowed(ctx context.Context, clientID, origin string
|
||||
if origin == "" {
|
||||
return nil
|
||||
}
|
||||
app, err := o.query.AppByOIDCClientID(ctx, clientID, false)
|
||||
app, err := o.query.AppByOIDCClientID(ctx, clientID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -331,7 +331,7 @@ func (o *OPStorage) checkOrgScopes(ctx context.Context, user *query.User, scopes
|
||||
func (o *OPStorage) setUserinfo(ctx context.Context, userInfo *oidc.UserInfo, userID, applicationID string, scopes []string, roleAudience []string) (err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
user, err := o.query.GetUserByID(ctx, true, userID, false)
|
||||
user, err := o.query.GetUserByID(ctx, true, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -645,7 +645,7 @@ func (o *OPStorage) GetPrivateClaimsFromScopes(ctx context.Context, userID, clie
|
||||
}
|
||||
|
||||
func (o *OPStorage) privateClaimsFlows(ctx context.Context, userID string, userGrants *query.UserGrants, claims map[string]interface{}) (map[string]interface{}, error) {
|
||||
user, err := o.query.GetUserByID(ctx, true, userID, false)
|
||||
user, err := o.query.GetUserByID(ctx, true, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -764,7 +764,7 @@ func (o *OPStorage) assertRoles(ctx context.Context, userID, applicationID strin
|
||||
if (applicationID == "" || len(requestedRoles) == 0) && len(roleAudience) == 0 {
|
||||
return nil, nil, nil
|
||||
}
|
||||
projectID, err := o.query.ProjectIDFromClientID(ctx, applicationID, false)
|
||||
projectID, err := o.query.ProjectIDFromClientID(ctx, applicationID)
|
||||
// applicationID might contain a username (e.g. client credentials) -> ignore the not found
|
||||
if err != nil && !errors.IsNotFound(err) {
|
||||
return nil, nil, err
|
||||
@@ -795,7 +795,7 @@ func (o *OPStorage) assertRoles(ctx context.Context, userID, applicationID strin
|
||||
if len(requestedRoles) > 0 {
|
||||
for _, requestedRole := range requestedRoles {
|
||||
for _, grant := range grants.UserGrants {
|
||||
checkGrantedRoles(roles, grant, requestedRole, grant.ProjectID == projectID)
|
||||
checkGrantedRoles(roles, *grant, requestedRole, grant.ProjectID == projectID)
|
||||
}
|
||||
}
|
||||
return grants, roles, nil
|
||||
@@ -823,7 +823,7 @@ func (o *OPStorage) assertUserMetaData(ctx context.Context, userID string) (map[
|
||||
}
|
||||
|
||||
func (o *OPStorage) assertUserResourceOwner(ctx context.Context, userID string) (map[string]string, error) {
|
||||
user, err := o.query.GetUserByID(ctx, true, userID, false)
|
||||
user, err := o.query.GetUserByID(ctx, true, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -838,7 +838,7 @@ func (o *OPStorage) assertUserResourceOwner(ctx context.Context, userID string)
|
||||
}, nil
|
||||
}
|
||||
|
||||
func checkGrantedRoles(roles *projectsRoles, grant *query.UserGrant, requestedRole string, isRequested bool) {
|
||||
func checkGrantedRoles(roles *projectsRoles, grant query.UserGrant, requestedRole string, isRequested bool) {
|
||||
for _, grantedRole := range grant.Roles {
|
||||
if requestedRole == grantedRole {
|
||||
roles.Add(grant.ProjectID, grantedRole, grant.ResourceOwner, grant.OrgPrimaryDomain, isRequested)
|
||||
@@ -854,6 +854,26 @@ type projectsRoles struct {
|
||||
requestProjectID string
|
||||
}
|
||||
|
||||
func newProjectRoles(projectID string, grants []query.UserGrant, requestedRoles []string) *projectsRoles {
|
||||
roles := new(projectsRoles)
|
||||
// if specific roles where requested, check if they are granted and append them in the roles list
|
||||
if len(requestedRoles) > 0 {
|
||||
for _, requestedRole := range requestedRoles {
|
||||
for _, grant := range grants {
|
||||
checkGrantedRoles(roles, grant, requestedRole, grant.ProjectID == projectID)
|
||||
}
|
||||
}
|
||||
return roles
|
||||
}
|
||||
// no specific roles were requested, so convert any grants into roles
|
||||
for _, grant := range grants {
|
||||
for _, role := range grant.Roles {
|
||||
roles.Add(grant.ProjectID, role, grant.ResourceOwner, grant.OrgPrimaryDomain, grant.ProjectID == projectID)
|
||||
}
|
||||
}
|
||||
return roles
|
||||
}
|
||||
|
||||
func (p *projectsRoles) Add(projectID, roleKey, orgID, domain string, isRequested bool) {
|
||||
if p.projects == nil {
|
||||
p.projects = make(map[string]projectRoles, 1)
|
||||
|
@@ -48,7 +48,7 @@ func TestOPStorage_SetUserinfoFromToken(t *testing.T) {
|
||||
assertUserinfo(t, userinfo)
|
||||
}
|
||||
|
||||
func TestOPStorage_SetIntrospectionFromToken(t *testing.T) {
|
||||
func TestServer_Introspect(t *testing.T) {
|
||||
project, err := Tester.CreateProject(CTX)
|
||||
require.NoError(t, err)
|
||||
app, err := Tester.CreateOIDCNativeClient(CTX, redirectURI, logoutRedirectURI, project.GetId())
|
||||
|
200
internal/api/oidc/introspect.go
Normal file
200
internal/api/oidc/introspect.go
Normal file
@@ -0,0 +1,200 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
errz "github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
)
|
||||
|
||||
func (s *Server) Introspect(ctx context.Context, r *op.Request[op.IntrospectionRequest]) (resp *op.Response, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
if s.features.LegacyIntrospection {
|
||||
return s.LegacyServer.Introspect(ctx, r)
|
||||
}
|
||||
if s.features.TriggerIntrospectionProjections {
|
||||
// Execute all triggers in one concurrent sweep.
|
||||
query.TriggerIntrospectionProjections(ctx)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
clientChan := make(chan *instrospectionClientResult)
|
||||
go s.instrospectionClientAuth(ctx, r.Data.ClientCredentials, clientChan)
|
||||
|
||||
tokenChan := make(chan *introspectionTokenResult)
|
||||
go s.introspectionToken(ctx, r.Data.Token, tokenChan)
|
||||
|
||||
var (
|
||||
client *instrospectionClientResult
|
||||
token *introspectionTokenResult
|
||||
)
|
||||
|
||||
// make sure both channels are always read,
|
||||
// and cancel the context on first error
|
||||
for i := 0; i < 2; i++ {
|
||||
var resErr error
|
||||
|
||||
select {
|
||||
case client = <-clientChan:
|
||||
resErr = client.err
|
||||
case token = <-tokenChan:
|
||||
resErr = token.err
|
||||
}
|
||||
|
||||
if resErr == nil {
|
||||
continue
|
||||
}
|
||||
cancel()
|
||||
|
||||
// we only care for the first error that occurred,
|
||||
// as the next error is most probably a context error.
|
||||
if err == nil {
|
||||
err = resErr
|
||||
}
|
||||
}
|
||||
|
||||
// only client auth errors should be returned
|
||||
var target *oidc.Error
|
||||
if errors.As(err, &target) && target.ErrorType == oidc.UnauthorizedClient {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// remaining errors shoudn't be returned to the client,
|
||||
// so we catch errors here, log them and return the response
|
||||
// with active: false
|
||||
defer func() {
|
||||
if err != nil {
|
||||
s.getLogger(ctx).ErrorContext(ctx, "oidc introspection", "err", err)
|
||||
resp, err = op.NewResponse(new(oidc.IntrospectionResponse)), nil
|
||||
}
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO: can we get rid of this separate query?
|
||||
if token.isPAT {
|
||||
if err = s.assertClientScopesForPAT(ctx, token.accessToken, client.clientID, client.projectID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if err = validateIntrospectionAudience(token.audience, client.clientID, client.projectID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userInfo, err := s.userInfo(ctx, token.userID, client.projectID, token.scope, []string{client.projectID})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
introspectionResp := &oidc.IntrospectionResponse{
|
||||
Active: true,
|
||||
Scope: token.scope,
|
||||
ClientID: token.clientID,
|
||||
TokenType: oidc.BearerToken,
|
||||
Expiration: oidc.FromTime(token.tokenExpiration),
|
||||
IssuedAt: oidc.FromTime(token.tokenCreation),
|
||||
NotBefore: oidc.FromTime(token.tokenCreation),
|
||||
Audience: token.audience,
|
||||
Issuer: op.IssuerFromContext(ctx),
|
||||
JWTID: token.tokenID,
|
||||
}
|
||||
introspectionResp.SetUserInfo(userInfo)
|
||||
return op.NewResponse(introspectionResp), nil
|
||||
}
|
||||
|
||||
type instrospectionClientResult struct {
|
||||
clientID string
|
||||
projectID string
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *Server) instrospectionClientAuth(ctx context.Context, cc *op.ClientCredentials, rc chan<- *instrospectionClientResult) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
|
||||
clientID, projectID, err := func() (string, string, error) {
|
||||
client, err := s.clientFromCredentials(ctx, cc)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
if cc.ClientAssertion != "" {
|
||||
verifier := op.NewJWTProfileVerifierKeySet(keySetMap(client.PublicKeys), op.IssuerFromContext(ctx), time.Hour, time.Second)
|
||||
if _, err := op.VerifyJWTAssertion(ctx, cc.ClientAssertion, verifier); err != nil {
|
||||
return "", "", oidc.ErrUnauthorizedClient().WithParent(err)
|
||||
}
|
||||
} else {
|
||||
if err := crypto.CompareHash(client.ClientSecret, []byte(cc.ClientSecret), s.hashAlg); err != nil {
|
||||
return "", "", oidc.ErrUnauthorizedClient().WithParent(err)
|
||||
}
|
||||
}
|
||||
|
||||
return client.ClientID, client.ProjectID, nil
|
||||
}()
|
||||
|
||||
span.EndWithError(err)
|
||||
|
||||
rc <- &instrospectionClientResult{
|
||||
clientID: clientID,
|
||||
projectID: projectID,
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// clientFromCredentials parses the client ID early,
|
||||
// and makes a single query for the client for either auth methods.
|
||||
func (s *Server) clientFromCredentials(ctx context.Context, cc *op.ClientCredentials) (client *query.IntrospectionClient, err error) {
|
||||
if cc.ClientAssertion != "" {
|
||||
claims := new(oidc.JWTTokenRequest)
|
||||
if _, err := oidc.ParseToken(cc.ClientAssertion, claims); err != nil {
|
||||
return nil, oidc.ErrUnauthorizedClient().WithParent(err)
|
||||
}
|
||||
client, err = s.query.GetIntrospectionClientByID(ctx, claims.Issuer, true)
|
||||
} else {
|
||||
client, err = s.query.GetIntrospectionClientByID(ctx, cc.ClientID, false)
|
||||
}
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, oidc.ErrUnauthorizedClient().WithParent(err)
|
||||
}
|
||||
// any other error is regarded internal and should not be reported back to the client.
|
||||
return client, err
|
||||
}
|
||||
|
||||
type introspectionTokenResult struct {
|
||||
*accessToken
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *Server) introspectionToken(ctx context.Context, tkn string, rc chan<- *introspectionTokenResult) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
token, err := s.verifyAccessToken(ctx, tkn)
|
||||
span.EndWithError(err)
|
||||
|
||||
rc <- &introspectionTokenResult{
|
||||
accessToken: token,
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
|
||||
func validateIntrospectionAudience(audience []string, clientID, projectID string) error {
|
||||
if slices.ContainsFunc(audience, func(entry string) bool {
|
||||
return entry == clientID || entry == projectID
|
||||
}) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return errz.ThrowPermissionDenied(nil, "OIDC-sdg3G", "token is not valid for this client")
|
||||
}
|
@@ -12,7 +12,7 @@ import (
|
||||
|
||||
func (o *OPStorage) JWTProfileTokenType(ctx context.Context, request op.TokenRequest) (op.AccessTokenType, error) {
|
||||
mapJWTProfileScopesToAudience(ctx, request)
|
||||
user, err := o.query.GetUserByID(ctx, false, request.GetSubject(), false)
|
||||
user, err := o.query.GetUserByID(ctx, false, request.GetSubject())
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
@@ -3,14 +3,17 @@ package oidc
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
"github.com/jonboulle/clockwork"
|
||||
"github.com/zitadel/logging"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
@@ -19,6 +22,145 @@ import (
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
)
|
||||
|
||||
// keySetCache implements oidc.KeySet for Access Token verification.
|
||||
// Public Keys are cached in a 2-dimensional map of Instance ID and Key ID.
|
||||
// When a key is not present the queryKey function is called to obtain the key
|
||||
// from the database.
|
||||
type keySetCache struct {
|
||||
mtx sync.RWMutex
|
||||
instanceKeys map[string]map[string]query.PublicKey
|
||||
queryKey func(ctx context.Context, keyID string, current time.Time) (query.PublicKey, error)
|
||||
clock clockwork.Clock
|
||||
}
|
||||
|
||||
// newKeySet initializes a keySetCache and starts a purging Go routine,
|
||||
// which runs once every purgeInterval.
|
||||
// When the passed context is done, the purge routine will terminate.
|
||||
func newKeySet(background context.Context, purgeInterval time.Duration, queryKey func(ctx context.Context, keyID string, current time.Time) (query.PublicKey, error)) *keySetCache {
|
||||
k := &keySetCache{
|
||||
instanceKeys: make(map[string]map[string]query.PublicKey),
|
||||
queryKey: queryKey,
|
||||
clock: clockwork.FromContext(background), // defaults to real clock
|
||||
}
|
||||
go k.purgeOnInterval(background, k.clock.NewTicker(purgeInterval))
|
||||
return k
|
||||
}
|
||||
|
||||
func (k *keySetCache) purgeOnInterval(background context.Context, ticker clockwork.Ticker) {
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-background.Done():
|
||||
return
|
||||
case <-ticker.Chan():
|
||||
}
|
||||
|
||||
// do the actual purging
|
||||
k.mtx.Lock()
|
||||
for instanceID, keys := range k.instanceKeys {
|
||||
for keyID, key := range keys {
|
||||
if key.Expiry().Before(k.clock.Now()) {
|
||||
delete(keys, keyID)
|
||||
}
|
||||
}
|
||||
if len(keys) == 0 {
|
||||
delete(k.instanceKeys, instanceID)
|
||||
}
|
||||
}
|
||||
k.mtx.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (k *keySetCache) setKey(instanceID, keyID string, key query.PublicKey) {
|
||||
k.mtx.Lock()
|
||||
defer k.mtx.Unlock()
|
||||
|
||||
if keys, ok := k.instanceKeys[instanceID]; ok {
|
||||
keys[keyID] = key
|
||||
return
|
||||
}
|
||||
|
||||
k.instanceKeys[instanceID] = map[string]query.PublicKey{keyID: key}
|
||||
}
|
||||
|
||||
func (k *keySetCache) getKey(ctx context.Context, keyID string) (_ *jose.JSONWebKey, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
instanceID := authz.GetInstance(ctx).InstanceID()
|
||||
|
||||
k.mtx.RLock()
|
||||
key, ok := k.instanceKeys[instanceID][keyID]
|
||||
k.mtx.RUnlock()
|
||||
|
||||
if ok {
|
||||
if key.Expiry().After(k.clock.Now()) {
|
||||
return jsonWebkey(key), nil
|
||||
}
|
||||
return nil, errors.ThrowInvalidArgument(nil, "OIDC-Zoh9E", "Errors.Key.ExpireBeforeNow")
|
||||
}
|
||||
|
||||
key, err = k.queryKey(ctx, keyID, k.clock.Now())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
k.setKey(instanceID, keyID, key)
|
||||
return jsonWebkey(key), nil
|
||||
}
|
||||
|
||||
// VerifySignature implements the oidc.KeySet interface.
|
||||
func (k *keySetCache) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (_ []byte, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
if len(jws.Signatures) != 1 {
|
||||
return nil, errors.ThrowInvalidArgument(nil, "OIDC-Gid9s", "Errors.Token.Invalid")
|
||||
}
|
||||
key, err := k.getKey(ctx, jws.Signatures[0].Header.KeyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return jws.Verify(key)
|
||||
}
|
||||
|
||||
func jsonWebkey(key query.PublicKey) *jose.JSONWebKey {
|
||||
return &jose.JSONWebKey{
|
||||
KeyID: key.ID(),
|
||||
Algorithm: key.Algorithm(),
|
||||
Use: key.Use().String(),
|
||||
Key: key.Key(),
|
||||
}
|
||||
}
|
||||
|
||||
// keySetMap is a mapping of key IDs to public key data.
|
||||
type keySetMap map[string][]byte
|
||||
|
||||
// getKey finds the keyID and parses the public key data
|
||||
// into a JSONWebKey.
|
||||
func (k keySetMap) getKey(keyID string) (*jose.JSONWebKey, error) {
|
||||
pubKey, err := crypto.BytesToPublicKey(k[keyID])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &jose.JSONWebKey{
|
||||
Key: pubKey,
|
||||
KeyID: keyID,
|
||||
Use: domain.KeyUsageSigning.String(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// VerifySignature implements the oidc.KeySet interface.
|
||||
func (k keySetMap) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
|
||||
if len(jws.Signatures) != 1 {
|
||||
return nil, errors.ThrowInvalidArgument(nil, "OIDC-Eeth6", "Errors.Token.Invalid")
|
||||
}
|
||||
key, err := k.getKey(jws.Signatures[0].Header.KeyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return jws.Verify(key)
|
||||
}
|
||||
|
||||
const (
|
||||
locksTable = "projections.locks"
|
||||
signingKey = "signing_key"
|
||||
|
244
internal/api/oidc/key_test.go
Normal file
244
internal/api/oidc/key_test.go
Normal file
@@ -0,0 +1,244 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
"github.com/jonboulle/clockwork"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
)
|
||||
|
||||
type publicKey struct {
|
||||
id string
|
||||
alg string
|
||||
use domain.KeyUsage
|
||||
seq uint64
|
||||
expiry time.Time
|
||||
key any
|
||||
}
|
||||
|
||||
func (k *publicKey) ID() string {
|
||||
return k.id
|
||||
}
|
||||
|
||||
func (k *publicKey) Algorithm() string {
|
||||
return k.alg
|
||||
}
|
||||
|
||||
func (k *publicKey) Use() domain.KeyUsage {
|
||||
return k.use
|
||||
}
|
||||
|
||||
func (k *publicKey) Sequence() uint64 {
|
||||
return k.seq
|
||||
}
|
||||
|
||||
func (k *publicKey) Expiry() time.Time {
|
||||
return k.expiry
|
||||
}
|
||||
|
||||
func (k *publicKey) Key() any {
|
||||
return k.key
|
||||
}
|
||||
|
||||
var (
|
||||
clock = clockwork.NewFakeClock()
|
||||
keyDB = map[string]*publicKey{
|
||||
"key1": {
|
||||
id: "key1",
|
||||
alg: "alg",
|
||||
use: domain.KeyUsageSigning,
|
||||
seq: 1,
|
||||
expiry: clock.Now().Add(time.Minute),
|
||||
},
|
||||
"key2": {
|
||||
id: "key2",
|
||||
alg: "alg",
|
||||
use: domain.KeyUsageSigning,
|
||||
seq: 3,
|
||||
expiry: clock.Now().Add(10 * time.Hour),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
func queryKeyDB(_ context.Context, keyID string, current time.Time) (query.PublicKey, error) {
|
||||
if key, ok := keyDB[keyID]; ok {
|
||||
return key, nil
|
||||
}
|
||||
return nil, errors.New("not found")
|
||||
}
|
||||
|
||||
func Test_keySetCache(t *testing.T) {
|
||||
background, cancel := context.WithCancel(
|
||||
clockwork.AddToContext(context.Background(), clock),
|
||||
)
|
||||
defer cancel()
|
||||
|
||||
// create an empty keySet with a purge go routine, runs every Hour
|
||||
keySet := newKeySet(background, time.Hour, queryKeyDB)
|
||||
ctx := authz.NewMockContext("instanceID", "orgID", "userID")
|
||||
|
||||
// query error
|
||||
_, err := keySet.getKey(ctx, "key9")
|
||||
require.Error(t, err)
|
||||
|
||||
want := &jose.JSONWebKey{
|
||||
KeyID: "key1",
|
||||
Algorithm: "alg",
|
||||
Use: domain.KeyUsageSigning.String(),
|
||||
}
|
||||
|
||||
// get key first time, populate the cache
|
||||
got, err := keySet.getKey(ctx, "key1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, want, got)
|
||||
|
||||
// move time forward
|
||||
clock.Advance(5 * time.Minute)
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
// key should still be in cache
|
||||
keySet.mtx.RLock()
|
||||
_, ok := keySet.instanceKeys["instanceID"]["key1"]
|
||||
require.True(t, ok)
|
||||
keySet.mtx.RUnlock()
|
||||
|
||||
// the key is expired, should error
|
||||
_, err = keySet.getKey(ctx, "key1")
|
||||
require.Error(t, err)
|
||||
|
||||
want = &jose.JSONWebKey{
|
||||
KeyID: "key2",
|
||||
Algorithm: "alg",
|
||||
Use: domain.KeyUsageSigning.String(),
|
||||
}
|
||||
|
||||
// get the second key from DB
|
||||
got, err = keySet.getKey(ctx, "key2")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, want, got)
|
||||
|
||||
// move time forward
|
||||
clock.Advance(time.Hour)
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
// first key shoud be purged, second still present
|
||||
keySet.mtx.RLock()
|
||||
_, ok = keySet.instanceKeys["instanceID"]["key1"]
|
||||
require.False(t, ok)
|
||||
_, ok = keySet.instanceKeys["instanceID"]["key2"]
|
||||
require.True(t, ok)
|
||||
keySet.mtx.RUnlock()
|
||||
|
||||
// get the second key from cache
|
||||
got, err = keySet.getKey(ctx, "key2")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, want, got)
|
||||
|
||||
// move time forward
|
||||
clock.Advance(10 * time.Hour)
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
// now the cache should be empty
|
||||
keySet.mtx.RLock()
|
||||
assert.Empty(t, keySet.instanceKeys)
|
||||
keySet.mtx.RUnlock()
|
||||
}
|
||||
|
||||
func Test_keySetCache_VerifySignature(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
k := newKeySet(ctx, time.Second, queryKeyDB)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
jws *jose.JSONWebSignature
|
||||
}{
|
||||
{
|
||||
name: "invalid token",
|
||||
jws: &jose.JSONWebSignature{},
|
||||
},
|
||||
{
|
||||
name: "key not found",
|
||||
jws: &jose.JSONWebSignature{
|
||||
Signatures: []jose.Signature{{
|
||||
Header: jose.Header{
|
||||
KeyID: "xxx",
|
||||
},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "verify error",
|
||||
jws: &jose.JSONWebSignature{
|
||||
Signatures: []jose.Signature{{
|
||||
Header: jose.Header{
|
||||
KeyID: "key1",
|
||||
},
|
||||
}},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := k.VerifySignature(ctx, tt.jws)
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_keySetMap_VerifySignature(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
k keySetMap
|
||||
jws *jose.JSONWebSignature
|
||||
}{
|
||||
{
|
||||
name: "invalid signature",
|
||||
k: keySetMap{
|
||||
"key1": []byte("foo"),
|
||||
},
|
||||
jws: &jose.JSONWebSignature{},
|
||||
},
|
||||
{
|
||||
name: "parse error",
|
||||
k: keySetMap{
|
||||
"key1": []byte("foo"),
|
||||
},
|
||||
jws: &jose.JSONWebSignature{
|
||||
Signatures: []jose.Signature{{
|
||||
Header: jose.Header{
|
||||
KeyID: "key1",
|
||||
},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "verify error",
|
||||
k: keySetMap{
|
||||
"key1": []byte("-----BEGIN RSA PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAsvX9P58JFxEs5C+L+H7W\nduFSWL5EPzber7C2m94klrSV6q0bAcrYQnGwFOlveThsY200hRbadKaKjHD7qIKH\nDEe0IY2PSRht33Jye52AwhkRw+M3xuQH/7R8LydnsNFk2KHpr5X2SBv42e37LjkE\nslKSaMRgJW+v0KZ30piY8QsdFRKKaVg5/Ajt1YToM1YVsdHXJ3vmXFMtypLdxwUD\ndIaLEX6pFUkU75KSuEQ/E2luT61Q3ta9kOWm9+0zvi7OMcbdekJT7mzcVnh93R1c\n13ZhQCLbh9A7si8jKFtaMWevjayrvqQABEcTN9N4Hoxcyg6l4neZtRDk75OMYcqm\nDQIDAQAB\n-----END RSA PUBLIC KEY-----\n"),
|
||||
},
|
||||
jws: &jose.JSONWebSignature{
|
||||
Signatures: []jose.Signature{{
|
||||
Header: jose.Header{
|
||||
KeyID: "key1",
|
||||
},
|
||||
}},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := tt.k.VerifySignature(context.Background(), tt.jws)
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
}
|
@@ -45,6 +45,7 @@ type Config struct {
|
||||
DeviceAuth *DeviceAuthorizationConfig
|
||||
DefaultLoginURLV2 string
|
||||
DefaultLogoutURLV2 string
|
||||
Features Features
|
||||
}
|
||||
|
||||
type EndpointConfig struct {
|
||||
@@ -63,6 +64,11 @@ type Endpoint struct {
|
||||
URL string
|
||||
}
|
||||
|
||||
type Features struct {
|
||||
TriggerIntrospectionProjections bool
|
||||
LegacyIntrospection bool
|
||||
}
|
||||
|
||||
type OPStorage struct {
|
||||
repo repository.Repository
|
||||
command *command.Commands
|
||||
@@ -120,7 +126,15 @@ func NewServer(
|
||||
|
||||
server := &Server{
|
||||
LegacyServer: op.NewLegacyServer(provider, endpoints(config.CustomEndpoints)),
|
||||
features: config.Features,
|
||||
repo: repo,
|
||||
query: query,
|
||||
command: command,
|
||||
keySet: newKeySet(context.TODO(), time.Hour, query.GetActivePublicKeyByID),
|
||||
fallbackLogger: fallbackLogger,
|
||||
hashAlg: crypto.NewBCrypt(10), // as we are only verifying in oidc, the cost is already part of the hash string and the config here is irrelevant.
|
||||
signingKeyAlgorithm: config.SigningKeyAlgorithm,
|
||||
assetAPIPrefix: assets.AssetAPI(externalSecure),
|
||||
}
|
||||
metricTypes := []metrics.MetricType{metrics.MetricTypeRequestCount, metrics.MetricTypeStatusCode, metrics.MetricTypeTotalCount}
|
||||
server.Handler = op.RegisterLegacyServer(server, op.WithHTTPMiddleware(
|
||||
|
@@ -4,16 +4,32 @@ import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
"golang.org/x/exp/slog"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/auth/repository"
|
||||
"github.com/zitadel/zitadel/internal/command"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
http.Handler
|
||||
*op.LegacyServer
|
||||
features Features
|
||||
|
||||
repo repository.Repository
|
||||
query *query.Queries
|
||||
command *command.Commands
|
||||
keySet *keySetCache
|
||||
|
||||
fallbackLogger *slog.Logger
|
||||
hashAlg crypto.HashAlgorithm
|
||||
signingKeyAlgorithm string
|
||||
assetAPIPrefix func(ctx context.Context) string
|
||||
}
|
||||
|
||||
func endpoints(endpointConfig *EndpointConfig) op.Endpoints {
|
||||
@@ -59,6 +75,13 @@ func endpoints(endpointConfig *EndpointConfig) op.Endpoints {
|
||||
return endpoints
|
||||
}
|
||||
|
||||
func (s *Server) getLogger(ctx context.Context) *slog.Logger {
|
||||
if logger, ok := logging.FromContext(ctx); ok {
|
||||
return logger
|
||||
}
|
||||
return s.fallbackLogger
|
||||
}
|
||||
|
||||
func (s *Server) IssuerFromRequest(r *http.Request) string {
|
||||
return s.Provider().IssuerFromRequest(r)
|
||||
}
|
||||
@@ -161,13 +184,6 @@ func (s *Server) DeviceToken(ctx context.Context, r *op.ClientRequest[oidc.Devic
|
||||
return s.LegacyServer.DeviceToken(ctx, r)
|
||||
}
|
||||
|
||||
func (s *Server) Introspect(ctx context.Context, r *op.Request[op.IntrospectionRequest]) (_ *op.Response, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
return s.LegacyServer.Introspect(ctx, r)
|
||||
}
|
||||
|
||||
func (s *Server) UserInfo(ctx context.Context, r *op.Request[oidc.UserInfoRequest]) (_ *op.Response, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
276
internal/api/oidc/userinfo.go
Normal file
276
internal/api/oidc/userinfo.go
Normal file
@@ -0,0 +1,276 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/dop251/goja"
|
||||
"github.com/zitadel/logging"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/actions"
|
||||
"github.com/zitadel/zitadel/internal/actions/object"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
)
|
||||
|
||||
func (s *Server) userInfo(ctx context.Context, userID, projectID string, scope, roleAudience []string) (_ *oidc.UserInfo, err error) {
|
||||
roleAudience, requestedRoles := prepareRoles(ctx, projectID, scope, roleAudience)
|
||||
qu, err := s.query.GetOIDCUserInfo(ctx, userID, roleAudience)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userInfo := userInfoToOIDC(projectID, qu, scope, roleAudience, requestedRoles, s.assetAPIPrefix(ctx))
|
||||
return userInfo, s.userinfoFlows(ctx, qu, userInfo)
|
||||
}
|
||||
|
||||
// prepareRoles scans the requested scopes, appends to roleAudiendce and returns the requestedRoles.
|
||||
//
|
||||
// When [ScopeProjectsRoles] is present and roleAudience was empty,
|
||||
// project IDs with the [domain.ProjectIDScope] prefix are added to the roleAudience.
|
||||
//
|
||||
// Scopes with [ScopeProjectRolePrefix] are added to requestedRoles.
|
||||
//
|
||||
// If the resulting requestedRoles or roleAudience are not not empty,
|
||||
// the current projectID will always be parts or roleAudience.
|
||||
// Else nil, nil is returned.
|
||||
func prepareRoles(ctx context.Context, projectID string, scope, roleAudience []string) (ra, requestedRoles []string) {
|
||||
// if all roles are requested take the audience for those from the scopes
|
||||
if slices.Contains(scope, ScopeProjectsRoles) && len(roleAudience) == 0 {
|
||||
roleAudience = domain.AddAudScopeToAudience(ctx, roleAudience, scope)
|
||||
}
|
||||
requestedRoles = make([]string, 0, len(scope))
|
||||
for _, s := range scope {
|
||||
if role, ok := strings.CutPrefix(s, ScopeProjectRolePrefix); ok {
|
||||
requestedRoles = append(requestedRoles, role)
|
||||
}
|
||||
}
|
||||
if len(requestedRoles) == 0 && len(roleAudience) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if projectID != "" && !slices.Contains(roleAudience, projectID) {
|
||||
roleAudience = append(roleAudience, projectID)
|
||||
}
|
||||
return roleAudience, requestedRoles
|
||||
}
|
||||
|
||||
func userInfoToOIDC(projectID string, user *query.OIDCUserInfo, scope, roleAudience, requestedRoles []string, assetPrefix string) *oidc.UserInfo {
|
||||
out := new(oidc.UserInfo)
|
||||
for _, s := range scope {
|
||||
switch s {
|
||||
case oidc.ScopeOpenID:
|
||||
out.Subject = user.User.ID
|
||||
case oidc.ScopeEmail:
|
||||
out.UserInfoEmail = userInfoEmailToOIDC(user.User)
|
||||
case oidc.ScopeProfile:
|
||||
out.UserInfoProfile = userInfoProfileToOidc(user.User, assetPrefix)
|
||||
case oidc.ScopePhone:
|
||||
out.UserInfoPhone = userInfoPhoneToOIDC(user.User)
|
||||
case oidc.ScopeAddress:
|
||||
//TODO: handle address for human users as soon as implemented
|
||||
case ScopeUserMetaData:
|
||||
setUserInfoMetadata(user.Metadata, out)
|
||||
case ScopeResourceOwner:
|
||||
setUserInfoOrgClaims(user, out)
|
||||
default:
|
||||
if claim, ok := strings.CutPrefix(s, domain.OrgDomainPrimaryScope); ok {
|
||||
out.AppendClaims(domain.OrgDomainPrimaryClaim, claim)
|
||||
}
|
||||
if claim, ok := strings.CutPrefix(s, domain.OrgIDScope); ok {
|
||||
out.AppendClaims(domain.OrgIDClaim, claim)
|
||||
setUserInfoOrgClaims(user, out)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// prevent returning obtained grants if none where requested
|
||||
if (projectID != "" && len(requestedRoles) > 0) || len(roleAudience) > 0 {
|
||||
setUserInfoRoleClaims(out, newProjectRoles(projectID, user.UserGrants, requestedRoles))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func userInfoEmailToOIDC(user *query.User) oidc.UserInfoEmail {
|
||||
if human := user.Human; human != nil {
|
||||
return oidc.UserInfoEmail{
|
||||
Email: string(human.Email),
|
||||
EmailVerified: oidc.Bool(human.IsEmailVerified),
|
||||
}
|
||||
}
|
||||
return oidc.UserInfoEmail{}
|
||||
}
|
||||
|
||||
func userInfoProfileToOidc(user *query.User, assetPrefix string) oidc.UserInfoProfile {
|
||||
if human := user.Human; human != nil {
|
||||
return oidc.UserInfoProfile{
|
||||
Name: human.DisplayName,
|
||||
GivenName: human.FirstName,
|
||||
FamilyName: human.LastName,
|
||||
Nickname: human.NickName,
|
||||
Picture: domain.AvatarURL(assetPrefix, user.ResourceOwner, user.Human.AvatarKey),
|
||||
Gender: getGender(human.Gender),
|
||||
Locale: oidc.NewLocale(human.PreferredLanguage),
|
||||
UpdatedAt: oidc.FromTime(user.ChangeDate),
|
||||
PreferredUsername: user.PreferredLoginName,
|
||||
}
|
||||
}
|
||||
if machine := user.Machine; machine != nil {
|
||||
return oidc.UserInfoProfile{
|
||||
Name: machine.Name,
|
||||
UpdatedAt: oidc.FromTime(user.ChangeDate),
|
||||
PreferredUsername: user.PreferredLoginName,
|
||||
}
|
||||
}
|
||||
return oidc.UserInfoProfile{}
|
||||
}
|
||||
|
||||
func userInfoPhoneToOIDC(user *query.User) oidc.UserInfoPhone {
|
||||
if human := user.Human; human != nil {
|
||||
return oidc.UserInfoPhone{
|
||||
PhoneNumber: string(human.Phone),
|
||||
PhoneNumberVerified: human.IsPhoneVerified,
|
||||
}
|
||||
}
|
||||
return oidc.UserInfoPhone{}
|
||||
}
|
||||
|
||||
func setUserInfoMetadata(metadata []query.UserMetadata, out *oidc.UserInfo) {
|
||||
if len(metadata) == 0 {
|
||||
return
|
||||
}
|
||||
mdmap := make(map[string]string, len(metadata))
|
||||
for _, md := range metadata {
|
||||
mdmap[md.Key] = base64.RawURLEncoding.EncodeToString(md.Value)
|
||||
}
|
||||
out.AppendClaims(ClaimUserMetaData, mdmap)
|
||||
}
|
||||
|
||||
func setUserInfoOrgClaims(user *query.OIDCUserInfo, out *oidc.UserInfo) {
|
||||
if org := user.Org; org != nil {
|
||||
out.AppendClaims(ClaimResourceOwner+"id", org.ID)
|
||||
out.AppendClaims(ClaimResourceOwner+"name", org.Name)
|
||||
out.AppendClaims(ClaimResourceOwner+"primary_domain", org.PrimaryDomain)
|
||||
}
|
||||
}
|
||||
|
||||
func setUserInfoRoleClaims(userInfo *oidc.UserInfo, roles *projectsRoles) {
|
||||
if roles != nil && len(roles.projects) > 0 {
|
||||
if roles, ok := roles.projects[roles.requestProjectID]; ok {
|
||||
userInfo.AppendClaims(ClaimProjectRoles, roles)
|
||||
}
|
||||
for projectID, roles := range roles.projects {
|
||||
userInfo.AppendClaims(fmt.Sprintf(ClaimProjectRolesFormat, projectID), roles)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) userinfoFlows(ctx context.Context, qu *query.OIDCUserInfo, userInfo *oidc.UserInfo) error {
|
||||
queriedActions, err := s.query.GetActiveActionsByFlowAndTriggerType(ctx, domain.FlowTypeCustomiseToken, domain.TriggerTypePreUserinfoCreation, qu.User.ResourceOwner)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctxFields := actions.SetContextFields(
|
||||
actions.SetFields("v1",
|
||||
actions.SetFields("claims", userinfoClaims(userInfo)),
|
||||
actions.SetFields("getUser", func(c *actions.FieldConfig) interface{} {
|
||||
return func(call goja.FunctionCall) goja.Value {
|
||||
return object.UserFromQuery(c, qu.User)
|
||||
}
|
||||
}),
|
||||
actions.SetFields("user",
|
||||
actions.SetFields("getMetadata", func(c *actions.FieldConfig) interface{} {
|
||||
return func(goja.FunctionCall) goja.Value {
|
||||
return object.UserMetadataListFromSlice(c, qu.Metadata)
|
||||
}
|
||||
}),
|
||||
actions.SetFields("grants", func(c *actions.FieldConfig) interface{} {
|
||||
return object.UserGrantsFromSlice(c, qu.UserGrants)
|
||||
}),
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
for _, action := range queriedActions {
|
||||
actionCtx, cancel := context.WithTimeout(ctx, action.Timeout())
|
||||
claimLogs := []string{}
|
||||
|
||||
apiFields := actions.WithAPIFields(
|
||||
actions.SetFields("v1",
|
||||
actions.SetFields("userinfo",
|
||||
actions.SetFields("setClaim", func(key string, value interface{}) {
|
||||
if userInfo.Claims[key] == nil {
|
||||
userInfo.AppendClaims(key, value)
|
||||
return
|
||||
}
|
||||
claimLogs = append(claimLogs, fmt.Sprintf("key %q already exists", key))
|
||||
}),
|
||||
actions.SetFields("appendLogIntoClaims", func(entry string) {
|
||||
claimLogs = append(claimLogs, entry)
|
||||
}),
|
||||
),
|
||||
actions.SetFields("claims",
|
||||
actions.SetFields("setClaim", func(key string, value interface{}) {
|
||||
if userInfo.Claims[key] == nil {
|
||||
userInfo.AppendClaims(key, value)
|
||||
return
|
||||
}
|
||||
claimLogs = append(claimLogs, fmt.Sprintf("key %q already exists", key))
|
||||
}),
|
||||
actions.SetFields("appendLogIntoClaims", func(entry string) {
|
||||
claimLogs = append(claimLogs, entry)
|
||||
}),
|
||||
),
|
||||
actions.SetFields("user",
|
||||
actions.SetFields("setMetadata", func(call goja.FunctionCall) goja.Value {
|
||||
if len(call.Arguments) != 2 {
|
||||
panic("exactly 2 (key, value) arguments expected")
|
||||
}
|
||||
key := call.Arguments[0].Export().(string)
|
||||
val := call.Arguments[1].Export()
|
||||
|
||||
value, err := json.Marshal(val)
|
||||
if err != nil {
|
||||
logging.WithError(err).Debug("unable to marshal")
|
||||
panic(err)
|
||||
}
|
||||
|
||||
metadata := &domain.Metadata{
|
||||
Key: key,
|
||||
Value: value,
|
||||
}
|
||||
if _, err = s.command.SetUserMetadata(ctx, metadata, userInfo.Subject, qu.User.ResourceOwner); err != nil {
|
||||
logging.WithError(err).Info("unable to set md in action")
|
||||
panic(err)
|
||||
}
|
||||
return nil
|
||||
}),
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
err = actions.Run(
|
||||
actionCtx,
|
||||
ctxFields,
|
||||
apiFields,
|
||||
action.Script,
|
||||
action.Name,
|
||||
append(actions.ActionToOptions(action), actions.WithHTTP(actionCtx), actions.WithUUID(actionCtx))...,
|
||||
)
|
||||
cancel()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(claimLogs) > 0 {
|
||||
userInfo.AppendClaims(fmt.Sprintf(ClaimActionLogFormat, action.Name), claimLogs)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
434
internal/api/oidc/userinfo_test.go
Normal file
434
internal/api/oidc/userinfo_test.go
Normal file
@@ -0,0 +1,434 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
)
|
||||
|
||||
func Test_prepareRoles(t *testing.T) {
|
||||
type args struct {
|
||||
projectID string
|
||||
scope []string
|
||||
roleAudience []string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantRa []string
|
||||
wantRequestedRoles []string
|
||||
}{
|
||||
{
|
||||
name: "empty scope and roleAudience",
|
||||
args: args{
|
||||
projectID: "projID",
|
||||
scope: nil,
|
||||
roleAudience: nil,
|
||||
},
|
||||
wantRa: nil,
|
||||
wantRequestedRoles: nil,
|
||||
},
|
||||
{
|
||||
name: "some scope and roleAudience",
|
||||
args: args{
|
||||
projectID: "projID",
|
||||
scope: []string{"openid", "profile"},
|
||||
roleAudience: []string{"project2"},
|
||||
},
|
||||
wantRa: []string{"project2", "projID"},
|
||||
wantRequestedRoles: []string{},
|
||||
},
|
||||
{
|
||||
name: "scope projects roles",
|
||||
args: args{
|
||||
projectID: "projID",
|
||||
scope: []string{ScopeProjectsRoles, domain.ProjectIDScope + "project2" + domain.AudSuffix},
|
||||
roleAudience: nil,
|
||||
},
|
||||
wantRa: []string{"project2", "projID"},
|
||||
wantRequestedRoles: []string{},
|
||||
},
|
||||
{
|
||||
name: "scope project role prefix",
|
||||
args: args{
|
||||
projectID: "projID",
|
||||
scope: []string{"openid", "profile", ScopeProjectRolePrefix + "foo", ScopeProjectRolePrefix + "bar"},
|
||||
roleAudience: nil,
|
||||
},
|
||||
wantRa: []string{"projID"},
|
||||
wantRequestedRoles: []string{"foo", "bar"},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotRa, gotRequestedRoles := prepareRoles(context.Background(), tt.args.projectID, tt.args.scope, tt.args.roleAudience)
|
||||
assert.Equal(t, tt.wantRa, gotRa, "roleAudience")
|
||||
assert.Equal(t, tt.wantRequestedRoles, gotRequestedRoles, "requestedRoles")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_userInfoToOIDC(t *testing.T) {
|
||||
metadata := []query.UserMetadata{
|
||||
{
|
||||
Key: "key1",
|
||||
Value: []byte{1, 2, 3},
|
||||
},
|
||||
{
|
||||
Key: "key2",
|
||||
Value: []byte{4, 5, 6},
|
||||
},
|
||||
}
|
||||
organization := &query.UserInfoOrg{
|
||||
ID: "orgID",
|
||||
Name: "orgName",
|
||||
PrimaryDomain: "orgDomain",
|
||||
}
|
||||
humanUserInfo := &query.OIDCUserInfo{
|
||||
User: &query.User{
|
||||
ID: "human1",
|
||||
CreationDate: time.Unix(123, 456),
|
||||
ChangeDate: time.Unix(567, 890),
|
||||
ResourceOwner: "orgID",
|
||||
Sequence: 22,
|
||||
State: domain.UserStateActive,
|
||||
Type: domain.UserTypeHuman,
|
||||
Username: "username",
|
||||
LoginNames: []string{"foo", "bar"},
|
||||
PreferredLoginName: "foo",
|
||||
Human: &query.Human{
|
||||
FirstName: "user",
|
||||
LastName: "name",
|
||||
NickName: "foobar",
|
||||
DisplayName: "xxx",
|
||||
AvatarKey: "picture.png",
|
||||
PreferredLanguage: language.Dutch,
|
||||
Gender: domain.GenderDiverse,
|
||||
Email: "foo@bar.com",
|
||||
IsEmailVerified: true,
|
||||
Phone: "+31123456789",
|
||||
IsPhoneVerified: true,
|
||||
},
|
||||
},
|
||||
Metadata: metadata,
|
||||
Org: organization,
|
||||
UserGrants: []query.UserGrant{
|
||||
{
|
||||
ID: "ug1",
|
||||
CreationDate: time.Unix(444, 444),
|
||||
ChangeDate: time.Unix(555, 555),
|
||||
Sequence: 55,
|
||||
Roles: []string{"role1", "role2"},
|
||||
GrantID: "grantID",
|
||||
State: domain.UserGrantStateActive,
|
||||
UserID: "human1",
|
||||
Username: "username",
|
||||
ResourceOwner: "orgID",
|
||||
ProjectID: "project1",
|
||||
OrgName: "orgName",
|
||||
OrgPrimaryDomain: "orgDomain",
|
||||
ProjectName: "projectName",
|
||||
UserResourceOwner: "org1",
|
||||
},
|
||||
},
|
||||
}
|
||||
machineUserInfo := &query.OIDCUserInfo{
|
||||
User: &query.User{
|
||||
ID: "machine1",
|
||||
CreationDate: time.Unix(123, 456),
|
||||
ChangeDate: time.Unix(567, 890),
|
||||
ResourceOwner: "orgID",
|
||||
Sequence: 23,
|
||||
State: domain.UserStateActive,
|
||||
Type: domain.UserTypeMachine,
|
||||
Username: "machine",
|
||||
PreferredLoginName: "meanMachine",
|
||||
Machine: &query.Machine{
|
||||
Name: "machine",
|
||||
Description: "I'm a robot",
|
||||
},
|
||||
},
|
||||
Org: organization,
|
||||
UserGrants: []query.UserGrant{
|
||||
{
|
||||
ID: "ug1",
|
||||
CreationDate: time.Unix(444, 444),
|
||||
ChangeDate: time.Unix(555, 555),
|
||||
Sequence: 55,
|
||||
Roles: []string{"role1", "role2"},
|
||||
GrantID: "grantID",
|
||||
State: domain.UserGrantStateActive,
|
||||
UserID: "human1",
|
||||
Username: "username",
|
||||
ResourceOwner: "orgID",
|
||||
ProjectID: "project1",
|
||||
OrgName: "orgName",
|
||||
OrgPrimaryDomain: "orgDomain",
|
||||
ProjectName: "projectName",
|
||||
UserResourceOwner: "org1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
type args struct {
|
||||
projectID string
|
||||
user *query.OIDCUserInfo
|
||||
scope []string
|
||||
roleAudience []string
|
||||
requestedRoles []string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *oidc.UserInfo
|
||||
}{
|
||||
{
|
||||
name: "human, empty",
|
||||
args: args{
|
||||
projectID: "project1",
|
||||
user: humanUserInfo,
|
||||
},
|
||||
want: &oidc.UserInfo{},
|
||||
},
|
||||
{
|
||||
name: "machine, empty",
|
||||
args: args{
|
||||
projectID: "project1",
|
||||
user: machineUserInfo,
|
||||
},
|
||||
want: &oidc.UserInfo{},
|
||||
},
|
||||
{
|
||||
name: "human, scope openid",
|
||||
args: args{
|
||||
projectID: "project1",
|
||||
user: humanUserInfo,
|
||||
scope: []string{oidc.ScopeOpenID},
|
||||
},
|
||||
want: &oidc.UserInfo{
|
||||
Subject: "human1",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "machine, scope openid",
|
||||
args: args{
|
||||
projectID: "project1",
|
||||
user: machineUserInfo,
|
||||
scope: []string{oidc.ScopeOpenID},
|
||||
},
|
||||
want: &oidc.UserInfo{
|
||||
Subject: "machine1",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "human, scope email",
|
||||
args: args{
|
||||
projectID: "project1",
|
||||
user: humanUserInfo,
|
||||
scope: []string{oidc.ScopeEmail},
|
||||
},
|
||||
want: &oidc.UserInfo{
|
||||
UserInfoEmail: oidc.UserInfoEmail{
|
||||
Email: "foo@bar.com",
|
||||
EmailVerified: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "machine, scope email",
|
||||
args: args{
|
||||
projectID: "project1",
|
||||
user: machineUserInfo,
|
||||
scope: []string{oidc.ScopeEmail},
|
||||
},
|
||||
want: &oidc.UserInfo{
|
||||
UserInfoEmail: oidc.UserInfoEmail{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "human, scope profile",
|
||||
args: args{
|
||||
projectID: "project1",
|
||||
user: humanUserInfo,
|
||||
scope: []string{oidc.ScopeProfile},
|
||||
},
|
||||
want: &oidc.UserInfo{
|
||||
UserInfoProfile: oidc.UserInfoProfile{
|
||||
Name: "xxx",
|
||||
GivenName: "user",
|
||||
FamilyName: "name",
|
||||
Nickname: "foobar",
|
||||
Picture: "https://foo.com/assets/orgID/picture.png",
|
||||
Gender: "diverse",
|
||||
Locale: oidc.NewLocale(language.Dutch),
|
||||
UpdatedAt: oidc.FromTime(time.Unix(567, 890)),
|
||||
PreferredUsername: "foo",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "machine, scope profile",
|
||||
args: args{
|
||||
projectID: "project1",
|
||||
user: machineUserInfo,
|
||||
scope: []string{oidc.ScopeProfile},
|
||||
},
|
||||
want: &oidc.UserInfo{
|
||||
UserInfoProfile: oidc.UserInfoProfile{
|
||||
Name: "machine",
|
||||
UpdatedAt: oidc.FromTime(time.Unix(567, 890)),
|
||||
PreferredUsername: "meanMachine",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "human, scope phone",
|
||||
args: args{
|
||||
projectID: "project1",
|
||||
user: humanUserInfo,
|
||||
scope: []string{oidc.ScopePhone},
|
||||
},
|
||||
want: &oidc.UserInfo{
|
||||
UserInfoPhone: oidc.UserInfoPhone{
|
||||
PhoneNumber: "+31123456789",
|
||||
PhoneNumberVerified: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "machine, scope phone",
|
||||
args: args{
|
||||
projectID: "project1",
|
||||
user: machineUserInfo,
|
||||
scope: []string{oidc.ScopePhone},
|
||||
},
|
||||
want: &oidc.UserInfo{
|
||||
UserInfoPhone: oidc.UserInfoPhone{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "human, scope metadata",
|
||||
args: args{
|
||||
projectID: "project1",
|
||||
user: humanUserInfo,
|
||||
scope: []string{ScopeUserMetaData},
|
||||
},
|
||||
want: &oidc.UserInfo{
|
||||
Claims: map[string]any{
|
||||
ClaimUserMetaData: map[string]string{
|
||||
"key1": base64.RawURLEncoding.EncodeToString([]byte{1, 2, 3}),
|
||||
"key2": base64.RawURLEncoding.EncodeToString([]byte{4, 5, 6}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "machine, scope metadata, none found",
|
||||
args: args{
|
||||
projectID: "project1",
|
||||
user: machineUserInfo,
|
||||
scope: []string{ScopeUserMetaData},
|
||||
},
|
||||
want: &oidc.UserInfo{},
|
||||
},
|
||||
{
|
||||
name: "machine, scope resource owner",
|
||||
args: args{
|
||||
projectID: "project1",
|
||||
user: machineUserInfo,
|
||||
scope: []string{ScopeResourceOwner},
|
||||
},
|
||||
want: &oidc.UserInfo{
|
||||
Claims: map[string]any{
|
||||
ClaimResourceOwner + "id": "orgID",
|
||||
ClaimResourceOwner + "name": "orgName",
|
||||
ClaimResourceOwner + "primary_domain": "orgDomain",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "human, scope org primary domain prefix",
|
||||
args: args{
|
||||
projectID: "project1",
|
||||
user: humanUserInfo,
|
||||
scope: []string{domain.OrgDomainPrimaryScope + "foo.com"},
|
||||
},
|
||||
want: &oidc.UserInfo{
|
||||
Claims: map[string]any{
|
||||
domain.OrgDomainPrimaryClaim: "foo.com",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "machine, scope org id",
|
||||
args: args{
|
||||
projectID: "project1",
|
||||
user: machineUserInfo,
|
||||
scope: []string{domain.OrgIDScope + "orgID"},
|
||||
},
|
||||
want: &oidc.UserInfo{
|
||||
Claims: map[string]any{
|
||||
domain.OrgIDClaim: "orgID",
|
||||
ClaimResourceOwner + "id": "orgID",
|
||||
ClaimResourceOwner + "name": "orgName",
|
||||
ClaimResourceOwner + "primary_domain": "orgDomain",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "human, roleAudience",
|
||||
args: args{
|
||||
projectID: "project1",
|
||||
user: humanUserInfo,
|
||||
roleAudience: []string{"project1"},
|
||||
},
|
||||
want: &oidc.UserInfo{
|
||||
Claims: map[string]any{
|
||||
ClaimProjectRoles: projectRoles{
|
||||
"role1": {"orgID": "orgDomain"},
|
||||
"role2": {"orgID": "orgDomain"},
|
||||
},
|
||||
fmt.Sprintf(ClaimProjectRolesFormat, "project1"): projectRoles{
|
||||
"role1": {"orgID": "orgDomain"},
|
||||
"role2": {"orgID": "orgDomain"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "human, requested roles",
|
||||
args: args{
|
||||
projectID: "project1",
|
||||
user: humanUserInfo,
|
||||
roleAudience: []string{"project1"},
|
||||
requestedRoles: []string{"role2"},
|
||||
},
|
||||
want: &oidc.UserInfo{
|
||||
Claims: map[string]any{
|
||||
ClaimProjectRoles: projectRoles{
|
||||
"role2": {"orgID": "orgDomain"},
|
||||
},
|
||||
fmt.Sprintf(ClaimProjectRolesFormat, "project1"): projectRoles{
|
||||
"role2": {"orgID": "orgDomain"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assetPrefix := "https://foo.com/assets"
|
||||
got := userInfoToOIDC(tt.args.projectID, tt.args.user, tt.args.scope, tt.args.roleAudience, tt.args.requestedRoles, assetPrefix)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user