mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-12 04:07:31 +00:00
feat: add saml custom attribute action and translations (#6341)
* feat: add saml custom attribute action and translations * chore: update saml dependency * fix: apply suggestions from code review Co-authored-by: Livio Spring <livio.a@gmail.com> * fix: custom attribute action with variadic parameter * docs: add customize saml response docs * docs: update docs/docs/apis/actions/customize-samlresponse.md Co-authored-by: Livio Spring <livio.a@gmail.com> * docs: update docs/docs/apis/actions/customize-samlresponse.md Co-authored-by: Livio Spring <livio.a@gmail.com> --------- Co-authored-by: Livio Spring <livio.a@gmail.com>
This commit is contained in:
@@ -19,6 +19,8 @@ func FlowTypeToDomain(flowType string) domain.FlowType {
|
||||
return domain.FlowTypeCustomiseToken
|
||||
case domain.FlowTypeInternalAuthentication.ID():
|
||||
return domain.FlowTypeInternalAuthentication
|
||||
case domain.FlowTypeCustomizeSAMLResponse.ID():
|
||||
return domain.FlowTypeCustomizeSAMLResponse
|
||||
default:
|
||||
return domain.FlowTypeUnspecified
|
||||
}
|
||||
@@ -47,6 +49,8 @@ func TriggerTypeToDomain(triggerType string) domain.TriggerType {
|
||||
return domain.TriggerTypePreAccessTokenCreation
|
||||
case domain.TriggerTypePreUserinfoCreation.ID():
|
||||
return domain.TriggerTypePreUserinfoCreation
|
||||
case domain.TriggerTypePreSAMLResponseCreation.ID():
|
||||
return domain.TriggerTypePreSAMLResponseCreation
|
||||
default:
|
||||
return domain.TriggerTypeUnspecified
|
||||
}
|
||||
|
@@ -18,6 +18,7 @@ func (s *Server) ListFlowTypes(ctx context.Context, _ *mgmt_pb.ListFlowTypesRequ
|
||||
action_grpc.FlowTypeToPb(domain.FlowTypeExternalAuthentication),
|
||||
action_grpc.FlowTypeToPb(domain.FlowTypeCustomiseToken),
|
||||
action_grpc.FlowTypeToPb(domain.FlowTypeInternalAuthentication),
|
||||
action_grpc.FlowTypeToPb(domain.FlowTypeCustomizeSAMLResponse),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
@@ -386,7 +386,7 @@ func (o *OPStorage) setUserinfo(ctx context.Context, userInfo *oidc.UserInfo, us
|
||||
}
|
||||
o.setUserInfoRoleClaims(userInfo, projectRoles)
|
||||
|
||||
return o.userinfoFlows(ctx, user.ResourceOwner, userGrants, userInfo)
|
||||
return o.userinfoFlows(ctx, user, userGrants, userInfo)
|
||||
}
|
||||
|
||||
func (o *OPStorage) setUserInfoProfile(ctx context.Context, userInfo *oidc.UserInfo, user *query.User) {
|
||||
@@ -457,8 +457,8 @@ func (o *OPStorage) setUserInfoRoleClaims(userInfo *oidc.UserInfo, roles *projec
|
||||
}
|
||||
}
|
||||
|
||||
func (o *OPStorage) userinfoFlows(ctx context.Context, resourceOwner string, userGrants *query.UserGrants, userInfo *oidc.UserInfo) error {
|
||||
queriedActions, err := o.query.GetActiveActionsByFlowAndTriggerType(ctx, domain.FlowTypeCustomiseToken, domain.TriggerTypePreUserinfoCreation, resourceOwner, false)
|
||||
func (o *OPStorage) userinfoFlows(ctx context.Context, user *query.User, userGrants *query.UserGrants, userInfo *oidc.UserInfo) error {
|
||||
queriedActions, err := o.query.GetActiveActionsByFlowAndTriggerType(ctx, domain.FlowTypeCustomiseToken, domain.TriggerTypePreUserinfoCreation, user.ResourceOwner, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -468,17 +468,13 @@ func (o *OPStorage) userinfoFlows(ctx context.Context, resourceOwner string, use
|
||||
actions.SetFields("claims", userinfoClaims(userInfo)),
|
||||
actions.SetFields("getUser", func(c *actions.FieldConfig) interface{} {
|
||||
return func(call goja.FunctionCall) goja.Value {
|
||||
user, err := o.query.GetUserByID(ctx, true, userInfo.Subject, false)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return object.UserFromQuery(c, user)
|
||||
}
|
||||
}),
|
||||
actions.SetFields("user",
|
||||
actions.SetFields("getMetadata", func(c *actions.FieldConfig) interface{} {
|
||||
return func(goja.FunctionCall) goja.Value {
|
||||
resourceOwnerQuery, err := query.NewUserMetadataResourceOwnerSearchQuery(resourceOwner)
|
||||
resourceOwnerQuery, err := query.NewUserMetadataResourceOwnerSearchQuery(user.ResourceOwner)
|
||||
if err != nil {
|
||||
logging.WithError(err).Debug("unable to create search query")
|
||||
panic(err)
|
||||
@@ -552,7 +548,7 @@ func (o *OPStorage) userinfoFlows(ctx context.Context, resourceOwner string, use
|
||||
Key: key,
|
||||
Value: value,
|
||||
}
|
||||
if _, err = o.command.SetUserMetadata(ctx, metadata, userInfo.Subject, resourceOwner); err != nil {
|
||||
if _, err = o.command.SetUserMetadata(ctx, metadata, userInfo.Subject, user.ResourceOwner); err != nil {
|
||||
logging.WithError(err).Info("unable to set md in action")
|
||||
panic(err)
|
||||
}
|
||||
@@ -665,10 +661,6 @@ func (o *OPStorage) privateClaimsFlows(ctx context.Context, userID string, userG
|
||||
}),
|
||||
actions.SetFields("getUser", func(c *actions.FieldConfig) interface{} {
|
||||
return func(call goja.FunctionCall) goja.Value {
|
||||
user, err := o.query.GetUserByID(ctx, true, userID, false)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return object.UserFromQuery(c, user)
|
||||
}
|
||||
}),
|
||||
@@ -807,7 +799,7 @@ func (o *OPStorage) assertRoles(ctx context.Context, userID, applicationID strin
|
||||
}
|
||||
return grants, roles, nil
|
||||
}
|
||||
// now specific roles were requested, so convert any grants into roles
|
||||
// no specific roles were requested, so convert any grants into roles
|
||||
for _, grant := range grants.UserGrants {
|
||||
for _, role := range grant.Roles {
|
||||
roles.Add(grant.ProjectID, role, grant.ResourceOwner, grant.OrgPrimaryDomain, grant.ProjectID == projectID)
|
||||
|
@@ -2,14 +2,19 @@ package saml
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/dop251/goja"
|
||||
"github.com/zitadel/logging"
|
||||
"github.com/zitadel/saml/pkg/provider"
|
||||
"github.com/zitadel/saml/pkg/provider/key"
|
||||
"github.com/zitadel/saml/pkg/provider/models"
|
||||
"github.com/zitadel/saml/pkg/provider/serviceprovider"
|
||||
"github.com/zitadel/saml/pkg/provider/xml/samlp"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/actions"
|
||||
"github.com/zitadel/zitadel/internal/actions/object"
|
||||
"github.com/zitadel/zitadel/internal/api/http/middleware"
|
||||
"github.com/zitadel/zitadel/internal/auth/repository"
|
||||
"github.com/zitadel/zitadel/internal/command"
|
||||
@@ -124,7 +129,7 @@ func (p *Storage) AuthRequestByID(ctx context.Context, id string) (_ models.Auth
|
||||
return AuthRequestFromBusiness(resp)
|
||||
}
|
||||
|
||||
func (p *Storage) SetUserinfoWithUserID(ctx context.Context, userinfo models.AttributeSetter, userID string, attributes []int) (err error) {
|
||||
func (p *Storage) SetUserinfoWithUserID(ctx context.Context, applicationID string, userinfo models.AttributeSetter, userID string, attributes []int) (err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
user, err := p.query.GetUserByID(ctx, true, userID, false)
|
||||
@@ -132,7 +137,17 @@ func (p *Storage) SetUserinfoWithUserID(ctx context.Context, userinfo models.Att
|
||||
return err
|
||||
}
|
||||
|
||||
setUserinfo(user, userinfo, attributes)
|
||||
userGrants, err := p.getGrants(ctx, userID, applicationID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
customAttributes, err := p.getCustomAttributes(ctx, user, userGrants)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
setUserinfo(user, userinfo, attributes, customAttributes)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -149,11 +164,14 @@ func (p *Storage) SetUserinfoWithLoginName(ctx context.Context, userinfo models.
|
||||
return err
|
||||
}
|
||||
|
||||
setUserinfo(user, userinfo, attributes)
|
||||
setUserinfo(user, userinfo, attributes, map[string]*customAttribute{})
|
||||
return nil
|
||||
}
|
||||
|
||||
func setUserinfo(user *query.User, userinfo models.AttributeSetter, attributes []int) {
|
||||
func setUserinfo(user *query.User, userinfo models.AttributeSetter, attributes []int, customAttributes map[string]*customAttribute) {
|
||||
for name, attr := range customAttributes {
|
||||
userinfo.SetCustomAttribute(name, "", attr.nameFormat, attr.attributeValue)
|
||||
}
|
||||
if len(attributes) == 0 {
|
||||
userinfo.SetUsername(user.PreferredLoginName)
|
||||
userinfo.SetUserID(user.ID)
|
||||
@@ -191,3 +209,139 @@ func setUserinfo(user *query.User, userinfo models.AttributeSetter, attributes [
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Storage) getCustomAttributes(ctx context.Context, user *query.User, userGrants *query.UserGrants) (map[string]*customAttribute, error) {
|
||||
customAttributes := make(map[string]*customAttribute, 0)
|
||||
queriedActions, err := p.query.GetActiveActionsByFlowAndTriggerType(ctx, domain.FlowTypeCustomizeSAMLResponse, domain.TriggerTypePreSAMLResponseCreation, user.ResourceOwner, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ctxFields := actions.SetContextFields(
|
||||
actions.SetFields("v1",
|
||||
actions.SetFields("getUser", func(c *actions.FieldConfig) interface{} {
|
||||
return func(call goja.FunctionCall) goja.Value {
|
||||
return object.UserFromQuery(c, user)
|
||||
}
|
||||
}),
|
||||
actions.SetFields("user",
|
||||
actions.SetFields("getMetadata", func(c *actions.FieldConfig) interface{} {
|
||||
return func(goja.FunctionCall) goja.Value {
|
||||
resourceOwnerQuery, err := query.NewUserMetadataResourceOwnerSearchQuery(user.ResourceOwner)
|
||||
if err != nil {
|
||||
logging.WithError(err).Debug("unable to create search query")
|
||||
panic(err)
|
||||
}
|
||||
metadata, err := p.query.SearchUserMetadata(
|
||||
ctx,
|
||||
true,
|
||||
user.ID,
|
||||
&query.UserMetadataSearchQueries{Queries: []query.SearchQuery{resourceOwnerQuery}},
|
||||
false,
|
||||
)
|
||||
if err != nil {
|
||||
logging.WithError(err).Info("unable to get md in action")
|
||||
panic(err)
|
||||
}
|
||||
return object.UserMetadataListFromQuery(c, metadata)
|
||||
}
|
||||
}),
|
||||
actions.SetFields("grants", func(c *actions.FieldConfig) interface{} {
|
||||
return object.UserGrantsFromQuery(c, userGrants)
|
||||
}),
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
for _, action := range queriedActions {
|
||||
actionCtx, cancel := context.WithTimeout(ctx, action.Timeout())
|
||||
|
||||
apiFields := actions.WithAPIFields(
|
||||
actions.SetFields("v1",
|
||||
actions.SetFields("attributes",
|
||||
actions.SetFields("setCustomAttribute", func(name string, nameFormat string, attributeValue ...string) {
|
||||
if _, ok := customAttributes[name]; !ok {
|
||||
customAttributes = appendCustomAttribute(customAttributes, name, nameFormat, attributeValue)
|
||||
return
|
||||
}
|
||||
}),
|
||||
),
|
||||
actions.SetFields("user",
|
||||
actions.SetFields("setMetadata", func(call goja.FunctionCall) {
|
||||
if len(call.Arguments) != 2 {
|
||||
panic("exactly 2 (key, value) arguments expected")
|
||||
}
|
||||
key := call.Arguments[0].Export().(string)
|
||||
val := call.Arguments[1].Export()
|
||||
|
||||
value, err := json.Marshal(val)
|
||||
if err != nil {
|
||||
logging.WithError(err).Debug("unable to marshal")
|
||||
panic(err)
|
||||
}
|
||||
|
||||
metadata := &domain.Metadata{
|
||||
Key: key,
|
||||
Value: value,
|
||||
}
|
||||
if _, err = p.command.SetUserMetadata(ctx, metadata, user.ID, user.ResourceOwner); err != nil {
|
||||
logging.WithError(err).Info("unable to set md in action")
|
||||
panic(err)
|
||||
}
|
||||
}),
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
err = actions.Run(
|
||||
actionCtx,
|
||||
ctxFields,
|
||||
apiFields,
|
||||
action.Script,
|
||||
action.Name,
|
||||
append(actions.ActionToOptions(action), actions.WithHTTP(actionCtx))...,
|
||||
)
|
||||
cancel()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return customAttributes, nil
|
||||
}
|
||||
|
||||
func (p *Storage) getGrants(ctx context.Context, userID, applicationID string) (*query.UserGrants, error) {
|
||||
projectID, err := p.query.ProjectIDFromClientID(ctx, applicationID, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
projectQuery, err := query.NewUserGrantProjectIDSearchQuery(projectID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userIDQuery, err := query.NewUserGrantUserIDSearchQuery(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return p.query.UserGrants(ctx, &query.UserGrantsQueries{
|
||||
Queries: []query.SearchQuery{
|
||||
projectQuery,
|
||||
userIDQuery,
|
||||
},
|
||||
}, true, false)
|
||||
}
|
||||
|
||||
type customAttribute struct {
|
||||
nameFormat string
|
||||
attributeValue []string
|
||||
}
|
||||
|
||||
func appendCustomAttribute(customAttributes map[string]*customAttribute, name string, nameFormat string, attributeValue []string) map[string]*customAttribute {
|
||||
if customAttributes == nil {
|
||||
customAttributes = make(map[string]*customAttribute)
|
||||
}
|
||||
customAttributes[name] = &customAttribute{
|
||||
nameFormat: nameFormat,
|
||||
attributeValue: attributeValue,
|
||||
}
|
||||
return customAttributes
|
||||
}
|
||||
|
Reference in New Issue
Block a user